Beispiel #1
0
def count_linear(m: nn.Linear, inputs: torch.Tensor, outputs: torch.Tensor) -> None:
    """Count the #ops of a linear layer and store them on ``m.total_ops``.

    One multiply-accumulate is counted per weight entry, i.e.
    ``in_features * out_features``. The count is independent of ``inputs``
    and ``outputs``: it ignores the batch and any other leading dimensions
    and does not include the bias. Only ops are computed; no parameter
    count is recorded.

    Args:
        m: The linear layer being measured. Its ``total_ops`` attribute is
            overwritten with a 1-element float64 tensor holding the count.
        inputs: Unused. Accepted so the signature follows the
            ``(module, inputs, outputs)`` hook convention (presumably used
            as a forward hook -- TODO confirm against caller).
        outputs: Unused; see ``inputs``.
    """
    total_ops = int(m.in_features) * int(m.out_features)
    # float64 holds every integer up to 2**53 exactly; a float32 tensor
    # (what ``torch.Tensor([...])`` creates) rounds counts above 2**24.
    m.total_ops = torch.tensor([total_ops], dtype=torch.float64)