def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False, force_cpu=False):
    """Count FLOPs and activations of ``model`` on a dummy all-ones batch.

    Args:
        model: Module to analyse. It must expose ``.parameters()`` and ``.to()``.
        input_size: Per-sample input shape without the batch dimension. Any
            sequence of ints is accepted (tuple, list, ``torch.Size``).
        batch_size: Leading batch dimension of the dummy input.
        detailed: When true, print the ``flop_count_str`` breakdown of the
            FLOP analysis to stdout.
        force_cpu: When true, move the model to CPU before analysing it.
            NOTE: for a ``torch.nn.Module``, ``.to()`` moves the module in
            place, so the caller's model is left on CPU afterwards.

    Returns:
        A ``(flop_total, activation_total)`` tuple taken from
        ``FlopCountAnalysis.total()`` and ``ActivationCountAnalysis.total()``.
    """
    if force_cpu:
        model = model.to('cpu')

    # Fetch the first parameter once; the default avoids a bare StopIteration
    # when the model has no parameters at all.
    first_param = next(iter(model.parameters()), None)
    if first_param is None:
        # Nothing to infer placement from: fall back to CPU and the default
        # floating-point dtype.
        device, dtype = torch.device('cpu'), torch.get_default_dtype()
    else:
        device, dtype = first_param.device, first_param.dtype

    # tuple() lets callers pass a list / torch.Size without a TypeError on
    # the concatenation.
    example_input = torch.ones((batch_size,) + tuple(input_size), device=device, dtype=dtype)

    fca = FlopCountAnalysis(model, example_input)
    aca = ActivationCountAnalysis(model, example_input)
    if detailed:
        print(flop_count_str(fca))
    return fca.total(), aca.total()