def _conv_raw_weight_hook(conv: nn.Conv2d, input_data: torch.Tensor, output_data: torch.Tensor):
    """
    Hook that records on ``conv`` how many FLOPs are saved per removed channel.

    Attributes written on ``conv``:

    `d_flops_in`: the FLOPs saved when the input channel count drops by 1

    `d_flops_out`: the FLOPs saved when the output channel count drops by 1
    (only set when ``conv.groups == 1``)

    Layers with ``groups != 1`` (SparseGate, and the depth-wise layers of
    MobileNet v2) must have ``groups == in_channels == out_channels``.
    Note that the values computed for a SparseGate should NOT be used.

    :param conv: the convolution the hook is attached to
    :param input_data: the layer input (unused)
    :param output_data: the layer output; only the size of its first element
        is read, unpacked as (channels, height, width)
    :raises AssertionError: if ``groups != 1`` and the layer is not depth-wise
    """
    is_plain_conv = conv.groups == 1
    if not is_plain_conv:
        # in MobileNet v2 the group count follows the input/output channel count
        assert conv.groups == conv.in_channels and conv.groups == conv.out_channels

    out_c, out_h, out_w = output_data[0].size()
    kernel_area = conv.kernel_size[0] * conv.kernel_size[1]

    if not is_plain_conv:
        # Depth-wise layer: removing a channel also removes a group, so the
        # cost is evaluated for the layer with one group (and channel) fewer.
        groups_after_drop = conv.groups - 1
        ops_per_input = kernel_area * (1 / groups_after_drop)
        # this layer itself is never pruned, so `d_flops_out` is left unset
        conv.d_flops_in = ops_per_input * (out_c - 1) * out_h * out_w
        return

    # Normal convolution.
    # total flops = kernel_ops * output_channels * output_height * output_width
    ops_per_input = kernel_area * (1 / conv.groups)
    ops_per_output = kernel_area * (conv.in_channels / conv.groups)

    flops_per_input = ops_per_input * out_c * out_h * out_w
    flops_per_output = ops_per_output * 1 * out_h * out_w  # a single output channel
    conv.d_flops_in = flops_per_input
    conv.d_flops_out = flops_per_output

    # emit one CSV record describing the layer and its per-channel savings
    record = (
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size[0],
        out_h,
        flops_per_input,
        flops_per_output,
    )
    print(",".join(str(field) for field in record))