def dump_flops_info(model, inputs, output_dir):
    """
    Estimate the FLOPs of ``model`` for ``inputs`` and save the reports.

    Nothing happens on non-main processes. On the main process a deep copy of
    the model is put into eval mode and measured with two flop counters in
    turn. A failure in either counter is logged with its traceback and does
    not prevent the other counter from running.

    Files written under ``output_dir``:
        flops_str_mobilecv.txt: ``str(model)`` after mobile_cv's estimation.
        flops_str_fvcore.txt: output of ``flop_count_str``.
        flops_table_fvcore.txt: output of ``flop_count_table`` (max_depth=10).

    A shallower table (max_depth=3) is only sent to the logger.

    Args:
        model: the model to analyze; only a deep copy of it is executed.
        inputs: passed as the single positional argument when calling the
            model, and handed unchanged to ``FlopCountAnalysis``.
        output_dir: directory in which the report files are created.
    """
    if not comm.is_main_process():
        return
    logger.info("Evaluating model's number of parameters and FLOPS")

    # Work on a private copy so eval mode is not applied to the caller's model.
    model = copy.deepcopy(model)
    model.eval()

    # 1. using mobile_cv flop counter
    try:
        estimator = flops_utils.FlopsEstimation(model)
        with estimator.enable():
            model(inputs)
            estimator.add_flops_info()
            # Presumably the estimator annotates the modules, so the string
            # form of the model carries the flops info -- confirm in mobile_cv.
            mobilecv_report = str(model)
        mobilecv_path = os.path.join(output_dir, "flops_str_mobilecv.txt")
        with PathManager.open(mobilecv_path, "w") as stream:
            stream.write(mobilecv_report)
            logger.info(f"Flops info written to {mobilecv_path}")
    except Exception:
        logger.exception(
            "Failed to estimate flops using mobile_cv's FlopsEstimation")

    # 2. using d2/fvcore's flop counter
    try:
        analysis = FlopCountAnalysis(model, inputs)

        # 2.1: dump as model str
        fvcore_report = flop_count_str(analysis)
        str_path = os.path.join(output_dir, "flops_str_fvcore.txt")
        with PathManager.open(str_path, "w") as stream:
            stream.write(fvcore_report)
            logger.info(f"Flops info written to {str_path}")

        # 2.2: dump as table
        deep_table = flop_count_table(analysis, max_depth=10)
        table_path = os.path.join(output_dir, "flops_table_fvcore.txt")
        with PathManager.open(table_path, "w") as stream:
            stream.write(deep_table)
            logger.info(f"Flops table written to {table_path}")

        # 2.3: print a table with a shallow depth
        shallow_table = flop_count_table(analysis, max_depth=3)
        logger.info("Flops table:\n" + shallow_table)
    except Exception:
        logger.exception(
            "Failed to estimate flops using detectron2's FlopCountAnalysis")
def do_flop(cfg):
    """
    Count the FLOPs of the model described by ``cfg`` and log a summary.

    The model and the test data loader are built from ``cfg``, the checkpoint
    is loaded, and the model is switched to eval mode. Up to
    ``args.num_inputs`` samples (``args`` is a module-level name) are taken
    from the data loader and analyzed one by one with ``FlopCountAnalysis``.

    Logged output:
        * the flop table of the last analyzed sample,
        * the average GFlops per operator type,
        * mean and standard deviation of the total GFlops.

    If no sample is analyzed (``args.num_inputs`` is 0 or the data loader is
    empty) a warning is logged and the function returns without a report,
    instead of failing on unbound loop variables.

    Args:
        cfg: either a ``CfgNode`` (yacs-style config) or a config object
            exposing ``dataloader.test``, ``model``, ``train.device`` and
            ``train.init_checkpoint`` that can be passed to ``instantiate``.
    """
    if isinstance(cfg, CfgNode):
        data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
        model = build_model(cfg)
        DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
    else:
        data_loader = instantiate(cfg.dataloader.test)
        model = instantiate(cfg.model)
        model.to(cfg.train.device)
        DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
    model.eval()

    counts = Counter()
    total_flops = []
    flops = None
    for idx, data in zip(tqdm.trange(args.num_inputs), data_loader):  # noqa
        flops = FlopCountAnalysis(model, data)
        if idx > 0:
            # Warnings about unsupported ops / uncalled modules are only
            # useful once; silence them for every sample after the first.
            flops.unsupported_ops_warnings(False).uncalled_modules_warnings(
                False)
        counts += flops.by_operator()
        total_flops.append(flops.total())

    if flops is None:
        # The loop body never ran, so there is nothing to summarize.
        logger.warning(
            "No input sample was analyzed; skipping the flops report.")
        return

    # One entry is appended per analyzed sample, so this equals idx + 1.
    num_samples = len(total_flops)

    logger.info("Flops table computed from only one input sample:\n" +
                flop_count_table(flops))
    logger.info("Average GFlops for each type of operators:\n" +
                str([(k, v / num_samples / 1e9) for k, v in counts.items()]))
    logger.info("Total GFlops: {:.1f}±{:.1f}".format(
        np.mean(total_flops) / 1e9, np.std(total_flops) / 1e9))
def test_detr_fbnet_export(self):
    """
    Build a DETR model with an FBNetV2 C4 backbone on CPU, script it with
    TorchScript, hand the eager and scripted models to
    ``self._assert_model_output`` and print the model's flop table.
    """
    runner = create_runner("d2go.projects.detr.runner.DETRRunner")
    cfg = runner.get_default_cfg()
    cfg.MODEL.DEVICE = "cpu"

    # DETR
    self._set_detr_cfg(cfg, 3, 3, 50, 256)

    # backbone
    cfg.MODEL.BACKBONE.NAME = "FBNetV2C4Backbone"
    fbnet_cfg = cfg.MODEL.FBNET_V2
    fbnet_cfg.ARCH = "FBNetV3_A_dsmask_C5"
    fbnet_cfg.WIDTH_DIVISOR = 8
    fbnet_cfg.OUT_FEATURES = ["trunk4"]

    # build model; only the inner ``detr`` module is exercised below
    detr = runner.build_model(cfg).eval().detr
    print(detr)

    scripted_detr = torch.jit.script(detr)
    self._assert_model_output(detr, scripted_detr)

    # print flops
    sample_inputs = ([torch.rand(3, 224, 320)], )
    analysis = FlopCountAnalysis(detr, sample_inputs)
    print(flop_count_table(analysis))
def dump_flops_info(model, inputs, output_dir, use_eval_mode=True):
    """
    Dump flops information about model, using the given model inputs.
    Information are dumped to output_dir using various flop counting tools
    in different formats. Only a simple table is printed to terminal.

    Args:
        model: the model to analyze. It is deep-copied first, so hooks are
            removed from, and eval mode is applied to, the copy only.
        inputs: a tuple of positional arguments used to call model with.
        output_dir: directory in which the report files are created.
        use_eval_mode: turn the model into eval mode for flop counting. Otherwise,
            will use the original mode. It's recommended to use eval mode, because
            training mode typically follows a different codepath.

    Returns:
        The ``FlopCountAnalysis`` object on success, ``float("nan")`` if the
        fvcore analysis failed, or ``None`` when called on a non-main process
        or when the model cannot be deep-copied.
    """
    if not comm.is_main_process():
        return
    logger.info("Evaluating model's number of parameters and FLOPS")

    try:
        model = copy.deepcopy(model)
    except Exception:
        logger.info("Failed to deepcopy the model and skip FlopsEstimation.")
        return

    # delete other forward_pre_hooks so they are not simultaneously called.
    # Use clear() rather than deleting keys while iterating over the mapping,
    # which raises RuntimeError when the mapping is mutated mid-iteration.
    model._forward_pre_hooks.clear()

    if use_eval_mode:
        model.eval()
    # Copy the inputs as well so the two counters below cannot modify the
    # caller's objects.
    inputs = copy.deepcopy(inputs)

    # 1. using mobile_cv flop counter
    try:
        fest = flops_utils.FlopsEstimation(model)
        with fest.enable():
            model(*inputs)
            fest.add_flops_info()
            model_str = str(model)
        output_file = os.path.join(output_dir, "flops_str_mobilecv.txt")
        with PathManager.open(output_file, "w") as f:
            f.write(model_str)
            logger.info(f"Flops info written to {output_file}")
    except Exception:
        logger.exception(
            "Failed to estimate flops using mobile_cv's FlopsEstimation")

    # 2. using d2/fvcore's flop counter
    output_file = os.path.join(output_dir, "flops_str_fvcore.txt")
    try:
        flops = FlopCountAnalysis(model, inputs)

        # 2.1: dump as model str
        model_str = flop_count_str(flops)
        with PathManager.open(output_file, "w") as f:
            f.write(model_str)
            logger.info(f"Flops info written to {output_file}")

        # 2.2: dump as table
        flops_table = flop_count_table(flops, max_depth=10)
        output_file = os.path.join(output_dir, "flops_table_fvcore.txt")
        with PathManager.open(output_file, "w") as f:
            f.write(flops_table)
            logger.info(f"Flops table (full version) written to {output_file}")

        # 2.3: print a table with a shallow depth
        flops_table = flop_count_table(flops, max_depth=3)
        logger.info("Flops table:\n" + flops_table)
    except Exception:
        # The traceback goes into whichever fvcore output file was current
        # when the failure happened (str file, or table file after step 2.1).
        with PathManager.open(output_file, "w") as f:
            traceback.print_exc(file=f)
        logger.warning(
            "Failed to estimate flops using detectron2's FlopCountAnalysis. "
            f"Error written to {output_file}.")
        flops = float("nan")
    return flops