예제 #1
0
파일: common.py 프로젝트: xue1234730/Prune
def _conv_raw_weight_hook(conv: nn.Conv2d, input_data: torch.Tensor,
                          output_data: torch.Tensor):
    """
    A hook (module, input, output signature) that sets `d_flops_in` and
    `d_flops_out` on a convolution layer.

    `d_flops_in`: the FLOPs drop when the input channel count drops by 1.

    `d_flops_out`: the FLOPs drop when the output channel count drops by 1.
    Only set for ordinary (`groups == 1`) convolutions.

    Supported layers are ordinary convolutions (`groups == 1`) and depthwise
    convolutions (`groups == in_channels == out_channels`, e.g. the depthwise
    layers of MobileNet v2, or SparseGate). For depthwise layers the group
    count follows the channel count, so dropping one channel also drops one
    output channel and one group.

    For ordinary convolutions, a comma-separated line
    `in_channels,out_channels,kernel_size[0],output_height,d_flops_in,d_flops_out`
    is printed to stdout on every call.

    Args:
        conv: the convolution layer the hook is attached to.
        input_data: the layer input (unused).
        output_data: the layer output; `output_data[0].size()` is unpacked as
            (channels, height, width), i.e. the first sample of the batch.

    Raises:
        AssertionError: if `conv` is a grouped convolution that is not
            depthwise.
    """
    if conv.groups != 1:
        # For SparseGate and the depthwise layers in MobileNet v2.
        # Note: the `d_flops_in` / `d_flops_out` of SparseGate should NOT be used.
        # In MobileNet v2 the group count changes with the input/output channels,
        # so only fully depthwise grouping is supported.
        assert conv.groups == conv.in_channels and conv.groups == conv.out_channels, (
            "only groups == 1 or depthwise (groups == in == out) convolutions are supported")

    output_channels, output_height, output_width = output_data[0].size()
    kernel_area = conv.kernel_size[0] * conv.kernel_size[1]

    # flops = kernel_area * (in_channels / groups) * output_channels * output_height * output_width
    if conv.groups == 1:
        # Normal conv layer. Float arithmetic is kept so the stored values
        # (and the printed CSV line) keep their float type/formatting.
        d_kernel_ops_in = float(kernel_area)                  # cost of one input channel
        kernel_ops = float(kernel_area * conv.in_channels)    # cost of all input channels
        conv.d_flops_in = d_kernel_ops_in * output_channels * output_height * output_width
        conv.d_flops_out = kernel_ops * output_height * output_width
        log_list = [
            conv.in_channels, conv.out_channels, conv.kernel_size[0],
            output_height, conv.d_flops_in, conv.d_flops_out
        ]
        print(",".join(str(item) for item in log_list))
    else:
        # Depthwise layer: flops = kernel_area * channels * H * W, so removing
        # one channel saves exactly kernel_area * H * W. Computed directly
        # instead of as kernel_area / (C - 1) * (C - 1), which accumulates
        # float rounding error (e.g. 1 / 49 * 49 != 1).
        # This layer will not be pruned, so do not set d_flops_out.
        conv.d_flops_in = float(kernel_area * output_height * output_width)