def test_fbnet_flops(self):
    """Build each listed model with ``fbnet`` and print its FLOPs report.

    Every model is created without pretrained weights and handed to
    ``flops_utils.print_model_flops`` together with an all-zero image batch.
    """
    model_names = ("fbnet_a", "fbnet_cse", "dmasking_f1")
    for model_name in model_names:
        print(f"model name: {model_name}")
        net = fbnet(model_name, pretrained=False)
        # Use the input size stored in the arch definition, or 224 if absent.
        side = net.arch_def.get("input_size", 224)
        dummy_batch = torch.zeros([1, 3, side, side])
        flops_utils.print_model_flops(net, dummy_batch)
def main(
    cfg,
    output_dir,
    runner,
    # binary specific optional arguments
    predictor_types: typing.List[str],
    device: str = "cpu",
    compare_accuracy: bool = False,
    skip_if_fail: bool = False,
):
    """Build the model, print its FLOPs and export one predictor per type.

    Args:
        cfg: config object; it is deep-copied before being modified here.
        output_dir: directory passed to the setup and export helpers.
        runner: object used to build the model and the data loader.
        predictor_types: predictor types to export, processed in order.
        device: value merged into ``MODEL.DEVICE`` of the copied config.
        compare_accuracy: not supported; a truthy value raises.
        skip_if_fail: when true, a failed export is logged and skipped
            instead of being re-raised.

    Returns:
        A dict with ``"predictor_paths"`` (predictor type -> exported path)
        and an empty ``"accuracy_comparison"`` dict.

    Raises:
        NotImplementedError: if ``compare_accuracy`` is truthy.
        Exception: whatever an export raised, unless ``skip_if_fail`` is set.
    """
    if compare_accuracy:
        # NOTE: "accuracy_comparison" is meant to hold metrics of all exported
        # models (and the original pytorch model); it is never filled in here.
        raise NotImplementedError(
            "compare_accuracy functionality isn't currently supported.")

    # All modifications below go to a copy, not to the caller's config object.
    cfg = copy.deepcopy(cfg)
    setup_after_launch(cfg, output_dir, runner)
    with temp_defrost(cfg):
        cfg.merge_from_list(["MODEL.DEVICE", device])
    model = runner.build_model(cfg, eval_only=True)

    # NOTE: train dataset is used to avoid leakage since the data might be used
    # for running calibration for quantization. The test loader is used so that
    # inference behaviour is followed (augmentation will not be applied).
    train_datasets = list(cfg.DATASETS.TRAIN)
    data_loader = runner.build_detection_test_loader(cfg, train_datasets)

    logger.info("Running the pytorch model and print FLOPS ...")
    sample_batch = next(iter(data_loader))
    flops_utils.print_model_flops(model, (sample_batch, ))

    predictor_paths: typing.Dict[str, str] = {}
    for predictor_type in predictor_types:
        # convert_and_export_predictor might alter the model, so each export
        # gets its own copy.
        model_copy = copy.deepcopy(model)
        try:
            exported_path = convert_and_export_predictor(
                cfg,
                model_copy,
                predictor_type,
                output_dir,
                data_loader,
            )
            logger.info(
                f"Predictor type {predictor_type} has been exported to {exported_path}")
            predictor_paths[predictor_type] = exported_path
        except Exception as err:
            logger.exception(f"Export {predictor_type} predictor failed: {err}")
            if not skip_if_fail:
                raise err

    return {"predictor_paths": predictor_paths, "accuracy_comparison": {}}
def main(config, input_shape=(320, 320)):
    """Build the model from ``config.model`` and print its FLOPs report.

    Args:
        config: object whose ``model`` attribute is passed to ``build_model``.
        input_shape: indexable whose first two items give the spatial size of
            the random input batch.

    Returns:
        None. If ``mobile_cv`` cannot be imported, a message is printed and
        the function returns before any FLOPs computation.
    """
    model = build_model(config.model)

    # The FLOPs helper is an optional dependency; bail out when it is missing.
    try:
        import mobile_cv.lut.lib.pt.flops_utils as flops_utils
    except ImportError:
        print("mobile-cv is not installed. Skip flops calculation.")
        return

    height, width = input_shape[0], input_shape[1]
    random_batch = torch.rand((1, 3, height, width))
    flops_utils.print_model_flops(model, (random_batch, ))
def test_fbnet_flops(self):
    """Build each listed model with ``fbnet`` and print its FLOPs report.

    The all-zero image batch is wrapped in a one-element tuple before being
    passed to ``flops_utils.print_model_flops``.
    """
    # Other entries ("FBNetV2_F1", "FBNetV2_F5") are currently commented out.
    model_names = ("fbnet_c", )
    for model_name in model_names:
        print(f"model name: {model_name}")
        net = fbnet(model_name, pretrained=False)
        # Use the input size stored in the arch definition, or 224 if absent.
        side = net.arch_def.get("input_size", 224)
        dummy_batch = torch.zeros([1, 3, side, side])
        flops_utils.print_model_flops(net, (dummy_batch, ))
def run_flops_estimation():
    """Print the FLOPs report of the ``dmasking_l2_hs`` model.

    The model is built without pretrained weights, switched to eval mode and
    evaluated on an all-zero image batch with gradient tracking disabled.
    """
    # fbnet models; supported models can be found in
    # mobile_cv/model_zoo/models/model_info/fbnet_v2/*.json
    net = fbnet("dmasking_l2_hs", pretrained=False)
    net.eval()

    # Use the input size stored in the arch definition, or 224 if absent.
    side = net.arch_def.get("input_size", 224)
    dummy_batch = torch.zeros([1, 3, side, side])

    with torch.no_grad():
        flops_utils.print_model_flops(net, dummy_batch)
def _create_and_run(self, arch_name, model_arch):
    """Build the blocks of ``model_arch``, print FLOPs and check the batch size.

    Args:
        arch_name: not used in the body.
        model_arch: architecture definition; its ``"blocks"`` entry is unified
            and built, and its optional ``"input_size"`` (default 224) sets
            the side length of the all-zero input.
    """
    unified_arch = fbnet_builder.unify_arch_def(model_arch, ["blocks"])
    net = fbnet_builder.FBNetBuilder().build_blocks(
        unified_arch["blocks"], dim_in=3)
    net.eval()

    side = model_arch.get("input_size", 224)
    dummy_batch = torch.zeros([1, 3, side, side])

    # The value returned by print_model_flops must have a leading dim of 1.
    result = flops_utils.print_model_flops(net, dummy_batch)
    self.assertEqual(result.shape[0], 1)
def main(
    args,
    output_dir,
    export_formats: typing.List[str] = DEFAULT_EXPORT_FORMATS,
    raise_if_failed: bool = False,
):
    """Print the model's FLOPs report, then export it in each requested format.

    Args:
        args: object providing ``task`` and ``task_args``; also forwarded to
            every exporter.
        output_dir: directory forwarded to every exporter.
        export_formats: export format names, processed in order.
        raise_if_failed: when true, an export failure is re-raised after being
            logged; otherwise it is only logged.

    Returns:
        A dict mapping each successfully exported format to the value
        returned by its exporter.
    """
    _import_tasks(args.task)
    task = task_factory.get(args.task, **args.task_args)

    model = task.get_model()
    model.eval()

    # One iterator is shared: the first batch is drawn here and the same
    # iterator is then handed to every exporter.
    batch_iter = iter(task.get_dataloader())
    first_batch = next(batch_iter)

    with torch.no_grad():
        flops_utils.print_model_flops(model, first_batch)

    exported = {}
    for export_format in export_formats:
        assert export_format not in exported, (
            f"Export format {export_format} has already existed.")
        try:
            exporter = ExportFactory.get(export_format)
            exported[export_format] = exporter(
                args,
                task,
                model,
                first_batch,
                output_dir,
                # NOTE: the output model may differ if the data loader is
                # used multiple times.
                data_iter=batch_iter,
            )
        except Exception as err:
            logger.warning(f"Export format {export_format} failed: {err}")
            if raise_if_failed:
                raise err
    return exported