def _linear_raw_weight_hook(linear_layer: nn.Linear, input_data: torch.Tensor, output_data: torch.Tensor): input_dim = linear_layer.in_features output_dim = linear_layer.out_features # flops = linear_layer.weight.nelement() assert linear_layer.weight.nelement() == (input_dim * output_dim) linear_layer.d_flops_in = (input_dim - 1) * output_dim linear_layer.d_flops_out = input_dim * (output_dim - 1)