Beispiel #1
0
def prune_conv2d_layer_with_craig(layer: nn.Conv2d,
                                  prune_percent_per_layer: float,
                                  similarity_metric: Union[Text, Dict] = "",
                                  prune_type: Text = "craig",
                                  **kwargs) -> Tuple[List[int], List[float]]:
    """Prune the output channels of ``layer`` in place down to a CRAIG subset.

    The kept output channels and their per-channel weights come from
    ``get_layer_craig_subset``. Each kept filter (and its bias, if the layer
    has one) is multiplied by its subset weight, and ``layer.out_channels``
    is updated to the number of kept channels.

    Args:
        layer: Convolution whose output channels are pruned (mutated in place).
        prune_percent_per_layer: Forwarded to ``get_layer_craig_subset``.
        similarity_metric: Forwarded to ``get_layer_craig_subset``.
        prune_type: Forwarded to ``get_layer_craig_subset``.
        **kwargs: Extra keyword arguments forwarded to ``get_layer_craig_subset``.

    Returns:
        ``(subset_nodes, subset_weights)`` exactly as returned by
        ``get_layer_craig_subset``: the indices of the kept output channels
        and the scale applied to each of them.
    """
    subset_nodes, subset_weights = get_layer_craig_subset(
        layer=layer,
        original_num_nodes=layer.out_channels,
        prune_percent_per_layer=prune_percent_per_layer,
        similarity_metric=similarity_metric,
        prune_type=prune_type,
        **kwargs)

    num_nodes = len(subset_nodes)
    weight = layer.weight.detach()

    # Build the scale factors with the weight's own dtype and device. A bare
    # torch.tensor(...) always lands on the CPU (device mismatch for CUDA
    # layers), and float64 inputs such as NumPy arrays would promote the
    # pruned weights to float64.
    scale = torch.as_tensor(subset_weights, dtype=weight.dtype,
                            device=weight.device)

    # Keep the selected filters (dim 0 = output channels) and scale each one.
    layer.weight = nn.Parameter(weight[subset_nodes] *
                                scale.reshape((num_nodes, 1, 1, 1)))
    if layer.bias is not None:
        # Biases are scaled by the same per-channel factor as their filters.
        layer.bias = nn.Parameter(layer.bias.detach()[subset_nodes] * scale)
    layer.out_channels = num_nodes

    return subset_nodes, subset_weights
def prune_conv2d(module: Conv2d, in_keep_idxes=None, out_keep_idxes=None):
    """Prune the input and/or output channels of a convolution in place.

    Args:
        module: Convolution to prune (mutated in place).
        in_keep_idxes: Indices to keep along ``weight`` dim 1 (for a depthwise
            conv: the input channels to keep). ``None`` keeps all of them.
        out_keep_idxes: Output-channel indices (``weight`` dim 0) to keep.
            ``None`` keeps all of them.

    Returns:
        ``(in_keep_idxes, out_keep_idxes)``, with any ``None`` replaced by
        the full index list that was actually used.

    Raises:
        AssertionError: If the index lists are inconsistent with the module.
            All checks run before the module is modified.
    """
    is_depthwise = is_depthwise_conv2d(module)

    if in_keep_idxes is None:
        # PyTorch stores conv weights as (out, in / groups, kh, kw). For a
        # depthwise conv, dim 1 is therefore 1, so "keep every input channel"
        # must range over in_channels rather than weight.shape[1].
        num_in = module.in_channels if is_depthwise else module.weight.shape[1]
        in_keep_idxes = list(range(num_in))

    if out_keep_idxes is None:
        out_keep_idxes = list(range(module.weight.shape[0]))

    # Validate everything up front so a failed check leaves the module untouched.
    if is_depthwise:
        assert len(in_keep_idxes) == len(out_keep_idxes), (
            f"depthwise conv: len(in_keep_idxes): {len(in_keep_idxes)} != "
            f"len(out_keep_idxes): {len(out_keep_idxes)}"
        )
    else:
        assert (
            len(in_keep_idxes) <= module.weight.shape[1]
        ), f"len(in_keep_idxes): {len(in_keep_idxes)}, module.weight.shape[1]: {module.weight.shape[1]}"

    assert (
        len(out_keep_idxes) <= module.weight.shape[0]
    ), f"len(out_keep_idxes): {len(out_keep_idxes)}, module.weight.shape[0]: {module.weight.shape[0]}"

    if is_depthwise:
        # One group per remaining channel keeps the conv depthwise.
        module.groups = len(in_keep_idxes)
    module.out_channels = len(out_keep_idxes)
    module.in_channels = len(in_keep_idxes)

    # Slice output channels first. For a non-depthwise conv, also slice the
    # input dimension. Depthwise filters have no input dimension to select.
    weight = module.weight.data[out_keep_idxes, :, :, :]
    if not is_depthwise:
        weight = weight[:, in_keep_idxes, :, :]
    module.weight = torch.nn.Parameter(weight)
    module.weight.grad = None

    if module.bias is not None:
        module.bias = torch.nn.Parameter(module.bias.data[out_keep_idxes])
        module.bias.grad = None

    return in_keep_idxes, out_keep_idxes
Beispiel #3
0
def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Fold the BatchNorm ``bn`` that follows ``conv`` into ``conv``, in place.

    ``convert_bn_params(bn)`` supplies a per-channel scale ``A`` and shift
    ``bias``; the fusion treats the batch norm as ``y = A * x + bias``.
    The conv filters are multiplied by ``A``. The conv bias becomes
    ``A * old_conv_bias + bias``, or just ``bias`` when the conv had none.

    Returns:
        The same ``conv`` object, modified in place.
    """
    A, bias = convert_bn_params(bn)
    # NOTE(review): assumes A holds one scale per output channel in a layout
    # whose transpose(0, 1) broadcasts over the weight's output-channel dim
    # (e.g. (1, C, 1, 1)) — confirm against convert_bn_params.
    scale = A.transpose(0, 1)
    conv.weight.data.mul_(scale)

    # reshape(-1) rather than squeeze(): squeeze() collapses a single-channel
    # bias to a 0-d tensor, which is not a valid conv bias of shape (C,).
    fused_bias = bias.reshape(-1)
    if conv.bias is not None:
        # BN(conv(x) + b) = A * conv(x) + (A * b + bias): the existing conv
        # bias must be scaled and kept, not discarded.
        fused_bias = fused_bias + scale.reshape(-1) * conv.bias.data
    conv.bias = nn.Parameter(fused_bias)

    return conv