def count_conv2d(m: nn.Conv2d, inputs: torch.Tensor, outputs: torch.Tensor):
    """Count the operations of a 2D convolution and store them on the module.

    The count is
    ``2 * (in_channels // groups) * out_channels * kH * kW * H_out * W_out``,
    i.e. one multiply and one add per kernel tap per output element.
    The batch dimension is not multiplied in, and bias additions are not
    counted. Parameters are not counted here.

    Args:
        m: The convolution module being profiled. Its ``total_ops`` attribute
            is overwritten with the result, stored as a one-element tensor.
        inputs: The module's input(s), passed through ``check_inputs``.
        outputs: The module's output tensor. Its last two dimensions give the
            output spatial size (H_out, W_out).
    """
    # NOTE(review): check_inputs is defined elsewhere in this file. The call is
    # kept so that any validation/normalization it performs still happens,
    # even though the count no longer depends on the input shape.
    check_inputs(inputs)

    kernel_h, kernel_w = m.kernel_size

    # Take the spatial size from the actual output instead of estimating it as
    # input_size / stride. That estimate ignores padding and dilation, and it
    # yields fractional sizes when the stride does not divide the input evenly.
    out_h, out_w = outputs.size(-2), outputs.size(-1)

    # nn.Conv2d requires in_channels to be divisible by groups, so this integer
    # division is exact. Integer math also avoids float rounding in the total.
    macs_per_output = (m.in_channels // m.groups) * kernel_h * kernel_w
    total_ops = 2 * macs_per_output * m.out_channels * out_h * out_w

    m.total_ops = torch.Tensor([int(total_ops)])