def prune_bn2d(module: BatchNorm2d, keep_idxes):
    # prune a BatchNorm2d layer in place, keeping only the channels listed in keep_idxes
    module.num_features = len(keep_idxes)
    module.weight.data = module.weight.data[keep_idxes]
    module.weight.grad = None
    module.bias.data = module.bias.data[keep_idxes]
    module.bias.grad = None
    module.running_mean = module.running_mean[keep_idxes]
    module.running_var = module.running_var[keep_idxes]
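
# A minimal usage sketch for prune_bn2d (an illustrative assumption, not part of the
# original snippet). It assumes BatchNorm2d was imported from torch.nn, as the
# signature suggests; the layer size and the kept channel indexes are made up.
bn = BatchNorm2d(8)
prune_bn2d(bn, keep_idxes=[0, 2, 5])  # keep only channels 0, 2 and 5
assert bn.num_features == 3 and bn.weight.shape[0] == 3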
Example 2
def prune_conv_layer(
        conv_layer: Union[nn.Conv2d, nn.Linear],
        bn_layer: nn.BatchNorm2d,
        sparse_layer: Union[nn.BatchNorm2d, SparseGate],
        in_channel_mask: np.ndarray,
        prune_output_mode: str,
        pruner: Callable[[np.ndarray], float],
        prune_mode: str,
        sparse_layer_in: typing.Optional[SparseGate] = None,
        prune_on="factor") -> typing.Tuple[np.ndarray, np.ndarray]:
    """
    Note: if the sparse_layer is SparseGate, the gate will be replaced by BatchNorm
    scaling factor. The value of the gate will be set to all ones.

    :param prune_output_mode: how to handle the output channel (case insensitive)
        "keep": keep the output channel intact
        "same": keep the output channel as same as input channel
        "prune": prune the output according to bn_layer
    :param pruner: the method to determine the pruning threshold.
    :param in_channel_mask: a 0-1 vector indicating whether the corresponding channel is pruned (0) or kept (1)
    :param bn_layer: the BatchNorm layer after the convolution layer
    :param conv_layer: the convolution layer to be pruned
    :param sparse_layer: the layer to determine the sparsity. Support BatchNorm2d and SparseGate.
    :param prune_mode: pruning mode (`str`), case-insensitive:
        - `"multiply"`: pruning threshold is determined by the multiplication of `sparse_layer` and `bn_layer`
            only available when `sparse_layer` is `SparseGate`
        - `None` or `"default"`: default behaviour. The pruning threshold is determined by `sparse_layer`
    :param prune_on: 'factor' (threshold the sparse scaling factors) or 'weight' (apply the pruner directly to the conv/linear weight).
    :param sparse_layer_in: the layer to determine the input sparsity. Support BatchNorm2d and SparseGate.
        When `sparse_layer_in` is None, there is no input sparse layer and
        the input channels are determined by the `in_channel_mask`.
        Note: `in_channel_mask` conflicts with `sparse_layer_in`; only one may be given.
    :return: a tuple `(out_channel_mask, in_channel_mask)`
    """
    assert isinstance(conv_layer, nn.Conv2d) or isinstance(
        conv_layer, nn.Linear), f"conv_layer got {conv_layer}"

    assert isinstance(sparse_layer, nn.BatchNorm2d) or \
           isinstance(sparse_layer, nn.BatchNorm1d) or isinstance(sparse_layer, SparseGate), \
        f"sparse_layer got {sparse_layer}"

    if in_channel_mask is not None and sparse_layer_in is not None:
        raise ValueError(
            "Conflict option: in_channel_mask and sparse_layer_in")

    # normalize string options (both are case-insensitive); prune_mode may be None, meaning "default"
    prune_mode = (prune_mode or "default").lower()
    prune_output_mode = prune_output_mode.lower()

    if prune_mode == 'multiply':
        if bn_layer is None:
            raise ValueError("Could not use multiply mode when bn is None")
        if not isinstance(sparse_layer, SparseGate):
            raise ValueError(
                f"Do not support prune_mode {prune_mode} when the sparse_layer is {sparse_layer}"
            )

    with torch.no_grad():
        conv_weight: torch.Tensor = conv_layer.weight.data.clone()

        # prune the input channel of the conv layer
        # if sparse_layer_in and in_channel_mask are both None, the input dim will NOT be pruned
        if sparse_layer_in is not None:
            if in_channel_mask is not None:
                raise ValueError("")
            sparse_weight_in: np.ndarray = sparse_layer_in.weight.view(
                -1).data.cpu().numpy()
            # the in_channel_mask will be overwrote
            input_threshold = pruner(sparse_weight_in)
            in_channel_mask: np.ndarray = sparse_weight_in > input_threshold  #获得掩码

        # convert mask to channel indexes
        idx_in = np.squeeze(np.argwhere(np.asarray(
            in_channel_mask)))  # np.argwhere returns the indexes of the non-zero (kept) channels
        if len(idx_in.shape) == 0:
            # expand the single scalar to array
            idx_in = np.expand_dims(idx_in, 0)

        # prune the input of the conv layer
        if isinstance(conv_layer, nn.Conv2d):
            if conv_layer.groups == 1:
                conv_weight = conv_weight[:, idx_in.tolist(), :, :]
            else:
                assert conv_weight.shape[
                    1] == 1, "only works for groups == num_channels"
        elif isinstance(conv_layer, nn.Linear):
            conv_weight = conv_weight[:, idx_in.tolist()]
        else:
            raise ValueError(f"unsupported conv layer type: {conv_layer}")

        # prune the output channel of the conv layer
        if prune_output_mode == "prune":
            if prune_on == 'factor':
                # sparse_layer.weight needs to be flattened, because the weight of SparseGate is not 1d
                sparse_weight: np.ndarray = sparse_layer.weight.view(
                    -1).data.cpu().numpy()
                if prune_mode == 'multiply':
                    bn_weight = bn_layer.weight.data.cpu().numpy()
                    sparse_weight = sparse_weight * bn_weight  # element-wise multiplication
                elif prune_mode != 'default':
                    raise ValueError(f"Do not support prune_mode {prune_mode}")

                # prune according to the sparse factors
                output_threshold = pruner(sparse_weight)
                out_channel_mask: np.ndarray = sparse_weight > output_threshold  # mask of kept output channels
            else:
                # prune_on == 'weight': the pruner works directly on the conv/linear weight
                out_channel_mask: np.ndarray = pruner(
                    conv_weight.data.cpu().numpy())

        elif prune_output_mode == "keep":
            # do not prune the output
            out_channels = conv_layer.out_channels if isinstance(
                conv_layer, nn.Conv2d) else conv_layer.out_features
            out_channel_mask = np.ones(out_channels)
        elif prune_output_mode == "same":
            # prune the output channel with the input mask
            # keep the conv layer in_channel == out_channel
            out_channel_mask = in_channel_mask
        else:
            raise ValueError(f"invalid prune_output_mode: {prune_output_mode}")

        if not np.any(out_channel_mask):  # np.any: True if at least one element is True
            # there is no channel left
            return out_channel_mask, in_channel_mask

        idx_out: np.ndarray = np.squeeze(
            np.argwhere(np.asarray(out_channel_mask)))
        if len(idx_out.shape) == 0:
            # 0-d scalar
            idx_out = np.expand_dims(idx_out, 0)

        if isinstance(conv_layer, nn.Conv2d):
            conv_weight = conv_weight[idx_out.tolist(), :, :, :]  # keep only the selected output channels
        elif isinstance(conv_layer, nn.Linear):
            conv_weight = conv_weight[idx_out.tolist(), :]
            linear_bias = conv_layer.bias.clone()
            linear_bias = linear_bias[idx_out.tolist()]
        else:
            raise ValueError(f"unsupported conv layer type: {conv_layer}")

        # change the property of the conv layer
        if isinstance(conv_layer, nn.Conv2d):
            conv_layer.in_channels = len(idx_in)
            conv_layer.out_channels = len(idx_out)
        elif isinstance(conv_layer, nn.Linear):
            conv_layer.in_features = len(idx_in)
            conv_layer.out_features = len(idx_out)
        conv_layer.weight.data = conv_weight
        if isinstance(conv_layer, nn.Linear):
            conv_layer.bias.data = linear_bias
        if isinstance(conv_layer, nn.Conv2d) and conv_layer.groups != 1:
            # set the new groups for dw layer (for MobileNet)
            conv_layer.groups = conv_layer.in_channels

        # prune the bn layer
        if bn_layer is not None:
            bn_layer.weight.data = bn_layer.weight.data[
                idx_out.tolist()].clone()
            bn_layer.bias.data = bn_layer.bias.data[
                idx_out.tolist()].clone()
            bn_layer.running_mean = bn_layer.running_mean[
                idx_out.tolist()].clone()
            bn_layer.running_var = bn_layer.running_var[
                idx_out.tolist()].clone()

            # set bn properties
            bn_layer.num_features = len(idx_out)

        # prune the gate
        if isinstance(sparse_layer, SparseGate):
            sparse_layer.prune(idx_out)
            # multiply the bn weight and SparseGate weight
            sparse_weight: torch.Tensor = sparse_layer.weight.view(-1)
            if bn_layer is not None:
                bn_layer.weight.data = (bn_layer.weight.data *
                                        sparse_weight).clone()
                bn_layer.bias.data = (bn_layer.bias.data *
                                      sparse_weight).clone()
            # the function of the SparseGate is now replaced by bn layers
            # the SparseGate should be disabled
            sparse_layer.set_ones()

    return out_channel_mask, in_channel_mask
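
# A minimal usage sketch for this variant (illustrative assumptions, not from the
# original repo): prune a Conv2d + BatchNorm2d pair, using the BN scaling factors
# as the sparse layer and a simple percentile threshold as the pruner.
# The percentile_pruner helper below is hypothetical.
import numpy as np
import torch.nn as nn

conv = nn.Conv2d(16, 32, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(32)
bn.weight.data.uniform_(0, 1)  # simulate trained scaling factors


def percentile_pruner(weight: np.ndarray) -> float:
    # channels whose factor is below the 30th percentile are removed
    return float(np.percentile(np.abs(weight), 30))


out_mask, in_mask = prune_conv_layer(
    conv_layer=conv,
    bn_layer=bn,
    sparse_layer=bn,              # BN factors determine the output sparsity
    in_channel_mask=np.ones(16),  # keep all input channels
    prune_output_mode="prune",
    pruner=percentile_pruner,
    prune_mode="default",
)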
Example 3
def prune_conv_layer(
    conv_layer: nn.Conv2d,
    bn_layer: nn.BatchNorm2d,
    sparse_layer_out: typing.Optional[Union[nn.BatchNorm2d, SparseGate]],
    in_channel_mask: typing.Optional[np.ndarray],
    prune_output_mode: str,
    pruner: Pruner,
    prune_mode: str,
    sparse_layer_in: typing.Optional[Union[nn.BatchNorm2d, SparseGate]],
) -> typing.Tuple[np.ndarray, np.ndarray]:
    """
    Note: if the sparse_layer is SparseGate, the gate will be replaced by BatchNorm
    scaling factor. The value of the gate will be set to all ones.

    :param prune_output_mode: how to handle the output channel (case insensitive)
        "keep": keep the output channel intact
        "same": keep the output channel as same as input channel
        "prune": prune the output according to bn_layer
    :param pruner: the method to determine the pruning mask.
    :param in_channel_mask: a 0-1 vector indicating whether the corresponding channel is pruned (0) or kept (1)
    :param bn_layer: the BatchNorm layer after the convolution layer
    :param conv_layer: the convolution layer to be pruned

    :param sparse_layer_out: the layer to determine the output sparsity. Support BatchNorm2d and SparseGate.
    :param sparse_layer_in: the layer to determine the input sparsity. Support BatchNorm2d and SparseGate.
        When `sparse_layer_in` is None, there is no input sparse layer and
        the input channels are determined by the `in_channel_mask`.
        Note: `in_channel_mask` conflicts with `sparse_layer_in`; only one may be given.

    :param prune_mode: pruning mode (`str`):
        - `"multiply"`: pruning threshold is determined by the multiplication of `sparse_layer` and `bn_layer`
            only available when `sparse_layer` is `SparseGate`
        - `None` or `"default"`: default behaviour. The pruning threshold is determined by `sparse_layer`
    :return out_channel_mask
    """
    assert isinstance(conv_layer, nn.Conv2d), f"conv_layer got {conv_layer}"

    assert isinstance(sparse_layer_out, nn.BatchNorm2d) or isinstance(
        sparse_layer_out, SparseGate), f"sparse_layer got {sparse_layer_out}"
    if in_channel_mask is not None and sparse_layer_in is not None:
        raise ValueError(
            "Conflict option: in_channel_mask and sparse_layer_in")

    # normalize string options (both are case-insensitive); prune_mode may be None, meaning "default"
    prune_mode = (prune_mode or "default").lower()
    prune_output_mode = prune_output_mode.lower()

    if prune_mode == 'multiply':
        if not isinstance(sparse_layer_out, SparseGate):
            raise ValueError(
                f"Do not support prune_mode {prune_mode} when the sparse_layer is {sparse_layer_out}"
            )

    with torch.no_grad():
        conv_weight: torch.Tensor = conv_layer.weight.data.clone()

        # prune the input channel of the conv layer
        # if sparse_layer_in and in_channel_mask are both None, the input dim will NOT be pruned
        if sparse_layer_in is not None:
            if in_channel_mask is not None:
                raise ValueError("")
            sparse_weight_in: np.ndarray = sparse_layer_in.weight.view(
                -1).data.cpu().numpy()
            # the in_channel_mask will be overwrote
            in_channel_mask = pruner(sparse_weight_in)

        if in_channel_mask is None:
            # neither in_channel_mask nor sparse_layer_in was given:
            # keep all input channels (the input dim is not pruned)
            in_channel_mask = np.ones(conv_layer.in_channels, dtype=bool)

        # prune the input channel according to the in_channel_mask
        # convert mask to channel indexes
        idx_in = np.squeeze(np.argwhere(np.asarray(in_channel_mask)))
        if len(idx_in.shape) == 0:
            # expand the single scalar to array
            idx_in = np.expand_dims(idx_in, 0)
        elif len(idx_in.shape) == 1 and idx_in.shape[0] == 0:
            # nothing left, prune the whole block
            out_channel_mask = np.full(conv_layer.out_channels, False)
            return out_channel_mask, in_channel_mask

        # prune the input of the conv layer
        if conv_layer.groups == 1:
            conv_weight = conv_weight[:, idx_in.tolist(), :, :]
        else:
            assert conv_weight.shape[
                1] == 1, "only works for groups == num_channels"

        # prune the output channel of the conv layer
        if prune_output_mode == "prune":
            # the sparse_layer.weight need to be flatten, because the weight of SparseGate is not 1d
            sparse_weight_out: np.ndarray = sparse_layer_out.weight.view(
                -1).data.cpu().numpy()
            if prune_mode == 'multiply':
                bn_weight = bn_layer.weight.data.cpu().numpy()
                sparse_weight_out = sparse_weight_out * bn_weight  # element-wise multiplication
            elif prune_mode != 'default':
                raise ValueError(f"Do not support prune_mode {prune_mode}")

            # prune and get the pruned mask
            out_channel_mask = pruner(sparse_weight_out)
        elif prune_output_mode == "keep":
            # do not prune the output
            out_channel_mask = np.ones(conv_layer.out_channels)
        elif prune_output_mode == "same":
            # prune the output channel with the input mask
            # keep the conv layer in_channel == out_channel
            out_channel_mask = in_channel_mask
        else:
            raise ValueError(f"invalid prune_output_mode: {prune_output_mode}")

        idx_out: np.ndarray = np.squeeze(
            np.argwhere(np.asarray(out_channel_mask)))
        if len(idx_out.shape) == 0:
            # expand the single scalar to array
            idx_out = np.expand_dims(idx_out, 0)
        elif len(idx_out.shape) == 1 and idx_out.shape[0] == 0:
            # no channel left
            # return mask directly
            # the block is supposed to become an identity mapping
            return out_channel_mask, in_channel_mask
        conv_weight = conv_weight[idx_out.tolist(), :, :, :]

        # change the property of the conv layer
        conv_layer.in_channels = len(idx_in)
        conv_layer.out_channels = len(idx_out)
        conv_layer.weight.data = conv_weight
        if conv_layer.groups != 1:
            # set the new groups for dw layer
            conv_layer.groups = conv_layer.in_channels

        # prune the bn layer
        bn_layer.weight.data = bn_layer.weight.data[idx_out.tolist()].clone()
        bn_layer.bias.data = bn_layer.bias.data[idx_out.tolist()].clone()
        bn_layer.running_mean = bn_layer.running_mean[idx_out.tolist()].clone()
        bn_layer.running_var = bn_layer.running_var[idx_out.tolist()].clone()

        # set bn properties
        bn_layer.num_features = len(idx_out)

        # prune the gate
        if isinstance(sparse_layer_out, SparseGate):
            sparse_layer_out.prune(idx_out)
            # multiply the bn weight and SparseGate weight
            sparse_weight_out: torch.Tensor = sparse_layer_out.weight.view(-1)
            bn_layer.weight.data = (bn_layer.weight.data *
                                    sparse_weight_out).clone()
            bn_layer.bias.data = (bn_layer.bias.data *
                                  sparse_weight_out).clone()
            # the function of the SparseGate is now replaced by bn layers
            # the SparseGate should be disabled
            sparse_layer_out.set_ones()

    return out_channel_mask, in_channel_mask
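
# A minimal usage sketch for this variant (illustrative assumptions, not from the
# original repo): here `pruner` is expected to return a boolean mask rather than a
# threshold. The mean_mask_pruner helper below is hypothetical; it keeps every
# channel whose scaling factor is above the mean.
import numpy as np
import torch.nn as nn

conv = nn.Conv2d(16, 32, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(32)
bn.weight.data.uniform_(0, 1)  # simulate trained scaling factors


def mean_mask_pruner(weight: np.ndarray) -> np.ndarray:
    return np.abs(weight) > np.abs(weight).mean()


out_mask, in_mask = prune_conv_layer(
    conv_layer=conv,
    bn_layer=bn,
    sparse_layer_out=bn,          # BN factors determine the output sparsity
    in_channel_mask=np.ones(16),  # keep all input channels
    prune_output_mode="prune",
    pruner=mean_mask_pruner,
    prune_mode="default",
    sparse_layer_in=None,
)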