def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
    # YOLOv5 TorchScript model export
    try:
        LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
        f = file.with_suffix('.torchscript')

        ts = torch.jit.trace(model, im, strict=False)
        d = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names}
        extra_files = {'config.txt': json.dumps(d)}  # torch._C.ExtraFilesMap()
        if optimize:  # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
            optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
        else:
            ts.save(str(f), _extra_files=extra_files)

        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        return f
    except Exception as e:
        LOGGER.info(f'{prefix} export failure: {e}')
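# Companion sketch (not part of the export above): reading the embedded
# config.txt metadata back out of a file saved via the non-mobile branch
# (ts.save). The path 'yolov5s.torchscript' is a hypothetical example.
import json
import torch

extra_files = {'config.txt': ''}  # keys to extract; torch.jit.load fills in the values
model = torch.jit.load('yolov5s.torchscript', _extra_files=extra_files)
meta = json.loads(extra_files['config.txt'])
print(meta['shape'], meta['stride'], meta['names'])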
def test_preserve_bundled_inputs_methods(self):
    class MyBundledInputModule(torch.nn.Module):
        def __init__(self):
            super(MyBundledInputModule, self).__init__()

        def forward(self, inputs):
            return inputs

    class MyIncompleteBundledInputModule(torch.nn.Module):
        def __init__(self):
            super(MyIncompleteBundledInputModule, self).__init__()

        def forward(self, inputs):
            return inputs

        @torch.jit.export
        def get_all_bundled_inputs(self):
            pass

    bi_module = torch.jit.script(MyBundledInputModule())
    module_optim_bi_not_preserved = optimize_for_mobile(bi_module)

    # Expected to be False since no bundled inputs methods were added
    self.assertFalse(
        hasattr(module_optim_bi_not_preserved, 'get_all_bundled_inputs')
        or hasattr(module_optim_bi_not_preserved, 'get_num_bundled_inputs'))

    # Add bundled inputs methods to the module
    torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
        bi_module, [(torch.tensor([1]),)], [])
    # Now they should be preserved
    module_optim_bi_preserved = optimize_for_mobile(bi_module)

    # All of the bundled inputs methods were preserved
    self.assertTrue(
        hasattr(module_optim_bi_preserved, 'get_all_bundled_inputs')
        and hasattr(module_optim_bi_preserved, 'get_num_bundled_inputs'))

    bundled_input = module_optim_bi_preserved.get_all_bundled_inputs()[0]
    module_optim_bi_preserved(*bundled_input)

    # If not all 3 bundled inputs methods are present in the module,
    # we will not try to preserve them unless specified by the user.
    incomplete_bi_module = torch.jit.script(MyIncompleteBundledInputModule())
    incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module)
    self.assertFalse(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs'))

    # Specifically preserve get_all_bundled_inputs even if it's the only
    # bundled inputs method available.
    incomplete_bi_module_optim = optimize_for_mobile(
        incomplete_bi_module, preserved_methods=['get_all_bundled_inputs'])
    self.assertTrue(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs'))
def quantize(model, data_loader, config="fbgemm", name="lanes"): # Configuration prep_config_dict = {"non_traceable_module_name": ["base", "deconv"]} qconfig = get_default_qconfig(config) qconfig_dict = {"": qconfig} model.load() model.eval() # Prepare Model model_prepared = prepare_fx(model, qconfig_dict, prepare_custom_config_dict=prep_config_dict) calibrate(model_prepared, data_loader) model_int_8 = convert_fx(model_prepared) # Model Description params = sum([np.prod(p.size()) for p in model.parameters()]) print("ORIGINAL") print("Number of Parameters: {:.1f}M".format(params / 1e6)) print(f"Number of Parameters: {params}M") params = sum([np.prod(p.size()) for p in model_int_8.parameters()]) print("QUANTIZED") print("Number of Parameters: {:.6f}M".format(params / 1e6)) print(f"Number of Parameters: {params}M") print_size_of_model(model_int_8) mobile_model = torch.jit.script(model_int_8) torchscript_mobile = optimize_for_mobile(mobile_model) torch.jit.save(torchscript_mobile, MODEL_MAIN_DIR + name + "_mobile.pt") torch.jit.save(torch.jit.script(model_int_8), MODEL_MAIN_DIR + "quantized_" + name + "Net.pt") return model_int_8
def _main():
    args = _parse_args()
    _init_logging(args.debug)
    loader = Loader()
    model = _get_model(args.model_file, args.dict_dir).eval()
    encoder = Encoder(model)
    decoder = _get_decoder()
    _LG.info(encoder)

    if args.quantize:
        _LG.info('Quantizing the model')
        model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
        encoder = tq.quantize_dynamic(
            encoder, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
        _LG.info(encoder)

    # test
    if args.test_file:
        _LG.info('Testing with %s', args.test_file)
        waveform = loader(args.test_file)
        emission = encoder(waveform)
        transcript = decoder(emission)
        _LG.info(transcript)

    torch.jit.script(loader).save(os.path.join(args.output_path, 'loader.zip'))
    torch.jit.script(decoder).save(os.path.join(args.output_path, 'decoder.zip'))
    scripted = torch.jit.script(encoder)
    if args.optimize_for_mobile:
        scripted = optimize_for_mobile(scripted)
    scripted.save(os.path.join(args.output_path, 'encoder.zip'))
def extract_torchscript(self):
    """Extract torchscript module from image classification model"""
    super().print_message("torchscript")
    scripted_module = torch.jit.script(self.model)
    optimized_module = mobile_optimizer.optimize_for_mobile(scripted_module)
    optimized_module.save("{}.pt".format(self.file_name))
def trace_and_save_torchscript(
    model: torch.nn.Module,
    inputs: typing.Tuple[typing.Any, ...],
    output_path: str,
    use_get_traceable=False,
    trace_type="trace",
    opt_for_mobile=False,
):
    logger.info("Tracing and saving TorchScript to {} ...".format(output_path))
    with torch.no_grad():
        if use_get_traceable:
            model = ju.get_traceable_model(model)
        if trace_type == "trace":
            script_model = torch.jit.trace(model, inputs, strict=False)
        else:
            script_model = torch.jit.script(model)
        if opt_for_mobile:
            logger.info("Running optimize_for_mobile...")
            script_model = optimize_for_mobile(script_model)

    os.makedirs(output_path, exist_ok=True)
    model_file = os.path.join(output_path, "model.jit")
    script_model.save(model_file)

    data_file = os.path.join(output_path, "data.pth")
    torch.save(inputs, data_file)

    return model_file
def save_mobile_model():
    model.eval()
    example = torch.rand(1, 3, 480, 640).to(device)
    traced_script_module = torch.jit.trace(model, example)
    optimized_traced_model = optimize_for_mobile(traced_script_module)
    optimized_traced_model.save("mobile_model.pt")
    print('===> Mobile model saved!')
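# Note: .save() above writes a full-JIT TorchScript file. If the target is the
# PyTorch Mobile lite-interpreter runtime, save a .ptl instead (hedged sketch,
# reusing optimized_traced_model from save_mobile_model; filename illustrative):
optimized_traced_model._save_for_lite_interpreter("mobile_model.ptl")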
def extract_torchscript(self):
    """Extract torchscript module from sentiment analysis model"""
    super().print_message("torchscript")
    scripted_module = torch.jit.trace(self.model, self.model_example)
    optimized_module = mobile_optimizer.optimize_for_mobile(scripted_module)
    optimized_module.save("{}.pt".format(self.file_name))
def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
    # TorchScript model export
    try:
        LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
        f = file.with_suffix('.torchscript.pt')

        ts = torch.jit.trace(model, im, strict=False)
        d = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names}
        extra_files = {'config.txt': json.dumps(d)}  # torch._C.ExtraFilesMap()
        (optimize_for_mobile(ts) if optimize else ts).save(f, _extra_files=extra_files)

        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
    except Exception as e:
        LOGGER.info(f'{prefix} export failure: {e}')
def test_export_torchvision_format():
    cfg_name = 'faster_rcnn_fbnetv3a_dsmask_C4.yaml'
    pytorch_model = model_zoo.get(cfg_name, trained=True)

    from typing import List, Dict

    class Wrapper(torch.nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model
            coco_idx_list = [
                1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19,
                20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38,
                39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
                56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75,
                76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90, 91
            ]
            self.coco_idx = torch.tensor(coco_idx_list)

        def forward(self, inputs: List[torch.Tensor]):
            x = inputs[0].unsqueeze(0) * 255
            scale = 320.0 / min(x.shape[-2], x.shape[-1])
            x = torch.nn.functional.interpolate(x, scale_factor=scale,
                                                mode="bilinear",
                                                align_corners=True,
                                                recompute_scale_factor=True)
            out = self.model(x[0])
            res: Dict[str, torch.Tensor] = {}
            res["boxes"] = out[0] / scale
            res["labels"] = torch.index_select(self.coco_idx, 0, out[1])
            res["scores"] = out[2]
            return inputs, [res]

    size_divisibility = max(pytorch_model.backbone.size_divisibility, 10)
    h, w = size_divisibility, size_divisibility * 2
    runner = create_runner("d2go.runner.GeneralizedRCNNRunner")
    cfg = model_zoo.get_config(cfg_name)
    datasets = list(cfg.DATASETS.TRAIN)
    data_loader = runner.build_detection_test_loader(cfg, datasets)

    predictor_path = convert_and_export_predictor(
        cfg,
        copy.deepcopy(pytorch_model),
        "torchscript_int8@tracing",
        './',
        data_loader,
    )

    orig_model = torch.jit.load(os.path.join(predictor_path, "model.jit"))
    wrapped_model = Wrapper(orig_model)
    # optionally do a forward
    wrapped_model([torch.rand(3, 600, 600)])
    scripted_model = torch.jit.script(wrapped_model)
    optimized_model = optimize_for_mobile(scripted_model)
    optimized_model.save("D2Go/d2go_optimized.pt")
def _benchmark_conv3d_pw_bn_relu_forward(**kwargs) -> Callable:
    assert kwargs["mode"] in ("original", "deployable"), (
        "kwargs['mode'] must be either 'original' or 'deployable', "
        "but got {}.".format(kwargs["mode"]))
    input_tensor = torch.randn((kwargs["input_blob_size"]))
    conv_block = Conv3dPwBnAct(
        kwargs["in_channels"],
        kwargs["out_channels"],
        use_bn=False,  # assume BN has already been fused for forward
    )

    if kwargs["mode"] == "deployable":
        conv_block.convert(kwargs["input_blob_size"])
    conv_block.eval()
    if kwargs["quantize"] is True:
        if kwargs["mode"] == "original":  # manually fuse conv and relu
            conv_block.kernel = torch.quantization.fuse_modules(
                conv_block.kernel, ["conv", "act.act"])
        conv_block = nn.Sequential(
            torch.quantization.QuantStub(),
            conv_block,
            torch.quantization.DeQuantStub(),
        )
        conv_block.qconfig = torch.quantization.get_default_qconfig("qnnpack")
        conv_block = torch.quantization.prepare(conv_block)
        try:
            conv_block = torch.quantization.convert(conv_block)
        except Exception as e:
            logging.info(
                "benchmark_conv3d_pw_bn_relu: "
                "catch exception '{}' with kwargs of {}".format(e, kwargs))

            def func_to_benchmark_dummy() -> None:
                return

            return func_to_benchmark_dummy

    traced_model = torch.jit.trace(conv_block, input_tensor, strict=False)
    if kwargs["quantize"] is False:
        traced_model = optimize_for_mobile(traced_model)

    logging.info(f"model arch: {traced_model}")

    def func_to_benchmark() -> None:
        try:
            _ = traced_model(input_tensor)
        except Exception as e:
            logging.info(
                "benchmark_conv3d_pw_bn_relu: "
                "catch exception '{}' with kwargs of {}".format(e, kwargs))
        return

    return func_to_benchmark
def script_and_serialize(model: nn.Module, path: str, opt_backend: Optional[str] = None):
    scripted_model = torch.jit.script(model)
    if opt_backend:
        scripted_model = optimize_for_mobile(script_module=scripted_model,
                                             backend=opt_backend)
    torch.jit.save(scripted_model, path)
def _benchmark_x3d_bottleneck_forward(**kwargs) -> Callable:
    assert kwargs["mode"] in ("original", "deployable"), (
        "kwargs['mode'] must be either 'original' or 'deployable', "
        "but got {}.".format(kwargs["mode"]))
    input_tensor = torch.randn((kwargs["input_blob_size"]))
    conv_block = X3dBottleneckBlock(
        kwargs["in_channels"],
        kwargs["mid_channels"],
        kwargs["out_channels"],
        use_bn=(False, False, False),  # Assume BN has been fused for forward
    )

    if kwargs["mode"] == "deployable":
        conv_block.convert(kwargs["input_blob_size"])
    conv_block.eval()
    if kwargs["quantize"] is True:
        conv_block = nn.Sequential(
            torch.quantization.QuantStub(),
            conv_block,
            torch.quantization.DeQuantStub(),
        )
        conv_block.qconfig = torch.quantization.get_default_qconfig("qnnpack")
        conv_block = torch.quantization.prepare(conv_block)
        try:
            conv_block = torch.quantization.convert(conv_block)
        except Exception as e:
            logging.info(
                "benchmark_x3d_bottleneck_forward: "
                "catch exception '{}' with kwargs of {}".format(e, kwargs))

            def func_to_benchmark_dummy() -> None:
                return

            return func_to_benchmark_dummy

    traced_model = torch.jit.trace(conv_block, input_tensor, strict=False)
    if kwargs["quantize"] is False:
        traced_model = optimize_for_mobile(traced_model)

    logging.info(f"model arch: {traced_model}")

    def func_to_benchmark() -> None:
        try:
            _ = traced_model(input_tensor)
        except Exception as e:
            logging.info(
                "benchmark_x3d_bottleneck_forward: "
                "catch exception '{}' with kwargs of {}".format(e, kwargs))
        return

    return func_to_benchmark
def trace_and_save_torchscript(
    model: nn.Module,
    inputs: Tuple[torch.Tensor],
    output_path: str,
    mobile_optimization: Optional[MobileOptimizationConfig] = None,
    _extra_files: Optional[Dict[str, bytes]] = None,
):
    logger.info("Tracing and saving TorchScript to {} ...".format(output_path))
    PathManager.mkdirs(output_path)
    if _extra_files is None:
        _extra_files = {}

    # TODO: patch_builtin_len depends on D2, we should either copy the function
    # or dynamically register D2's version.
    from detectron2.export.torchscript_patch import patch_builtin_len

    with torch.no_grad(), patch_builtin_len():
        script_model = torch.jit.trace(model, inputs)

    with make_temp_directory("trace_and_save_torchscript") as tmp_dir:

        @contextlib.contextmanager
        def _synced_local_file(rel_path):
            remote_file = os.path.join(output_path, rel_path)
            local_file = os.path.join(tmp_dir, rel_path)
            yield local_file
            PathManager.copy_from_local(local_file, remote_file, overwrite=True)

        with _synced_local_file("model.jit") as model_file:
            torch.jit.save(script_model, model_file, _extra_files=_extra_files)

        with _synced_local_file("data.pth") as data_file:
            torch.save(inputs, data_file)

        if mobile_optimization is not None:
            logger.info("Applying optimize_for_mobile ...")
            liteopt_model = optimize_for_mobile(
                script_model,
                optimization_blocklist=mobile_optimization.optimization_blocklist,
                preserved_methods=mobile_optimization.preserved_methods,
                backend=mobile_optimization.backend,
            )
            with _synced_local_file("mobile_optimized.ptl") as lite_path:
                liteopt_model._save_for_lite_interpreter(lite_path)
            # liteopt_model(*inputs)  # sanity check
            op_names = torch.jit.export_opnames(liteopt_model)
            logger.info("Operator names from lite interpreter:\n{}".format(
                "\n".join(op_names)))

            logger.info("Applying augment_model_with_bundled_inputs ...")
            augment_model_with_bundled_inputs(liteopt_model, [inputs])
            liteopt_model.run_on_bundled_input(0)  # sanity check
            with _synced_local_file("mobile_optimized_bundled.ptl") as lite_path:
                liteopt_model._save_for_lite_interpreter(lite_path)
def _quant_script_and_optimize(model):
    # Eager-mode post-training quantization, then script and mobile-optimize
    model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
    model.fuse_model()
    torch.quantization.prepare(model, inplace=True)
    model(torch.randn(4, 1, 4, 4))  # calibrate
    torch.quantization.convert(model, inplace=True)

    model = torch.jit.script(model)
    model_optim = optimize_for_mobile(model)
    return model, model_optim
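# Usage sketch for _quant_script_and_optimize, assuming a module that defines
# the fuse_model() hook it calls (MyFusableNet is hypothetical) and that the
# qnnpack engine is available on this platform:
import torch

class MyFusableNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.relu = torch.nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.relu(self.conv(self.quant(x))))

    def fuse_model(self):
        # Fuse conv+relu so the quantization passes see a single module
        torch.quantization.fuse_modules(self, [['conv', 'relu']], inplace=True)

model, model_optim = _quant_script_and_optimize(MyFusableNet().eval())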
def test_all_backport_functions(self):
    # Backport from the latest bytecode version to the minimum supported version.
    # Load and run the backported model, and check its version.
    class TestModule(torch.nn.Module):
        def __init__(self, v):
            super().__init__()
            self.x = v

        def forward(self, y: int):
            increment = torch.ones([2, 4], dtype=torch.float64)
            return self.x + y + increment

    module_input = 1
    expected_mobile_module_result = 3 * torch.ones([2, 4], dtype=torch.float64)

    # Temporary input and output model files are exported to a temporary folder
    with tempfile.TemporaryDirectory() as tmpdirname:
        tmp_input_model_path = Path(tmpdirname, "tmp_script_module.ptl")
        script_module = torch.jit.script(TestModule(1))
        optimized_scripted_module = optimize_for_mobile(script_module)
        optimized_scripted_module._save_for_lite_interpreter(str(tmp_input_model_path))

        current_from_version = _get_model_bytecode_version(tmp_input_model_path)
        current_to_version = current_from_version - 1
        tmp_output_model_path = Path(tmpdirname, "tmp_script_module_backport.ptl")

        while current_to_version >= MINIMUM_TO_VERSION:
            # Backport the latest model to `to_version`, writing it to the tmp
            # file "tmp_script_module_backport"
            backport_success = _backport_for_mobile(tmp_input_model_path,
                                                    tmp_output_model_path,
                                                    current_to_version)
            assert backport_success

            backport_version = _get_model_bytecode_version(tmp_output_model_path)
            assert backport_version == current_to_version

            # Load the backported model and run its forward method
            mobile_module = _load_for_lite_interpreter(str(tmp_output_model_path))
            mobile_module_result = mobile_module(module_input)
            torch.testing.assert_allclose(mobile_module_result,
                                          expected_mobile_module_result)
            current_to_version -= 1

        # Check the backport failure case
        backport_success = _backport_for_mobile(tmp_input_model_path,
                                                tmp_output_model_path,
                                                MINIMUM_TO_VERSION - 1)
        assert not backport_success
        # need to clean the folder before it closes, otherwise will run into
        # git not clean error
        shutil.rmtree(tmpdirname)
def test_mobilenet_optimize_for_mobile(self):
    m = torchvision.models.mobilenet_v3_small()
    m = torch.jit.script(m)
    m = optimize_for_mobile(m)

    # run forward 3 times until segfault, see https://github.com/pytorch/pytorch/issues/52463
    x = torch.zeros(1, 3, 56, 56)
    self.assertEqual(m(x).numel(), 1000)
    self.assertEqual(m(x).numel(), 1000)
    self.assertEqual(m(x).numel(), 1000)
def initialize_model(self):
    if self._torchscript_model is None:
        self._model = self._lazy_model()
        Path("jit").mkdir(exist_ok=True, parents=True)
        self._torchscript_model = torch.jit.trace(self._model,
                                                  self.data.eval,
                                                  check_trace=False)
        self._torchscript_model = mobile_optimizer.optimize_for_mobile(
            self._torchscript_model)
        self._torchscript_model.save(f"jit/{self.name}.pt")
def trace_and_serialize(model: nn.Module, example: torch.Tensor, path: str,
                        opt_backend: Optional[str] = None):
    with torch.no_grad():
        traced_model = torch.jit.trace(model, example_inputs=example)
    if opt_backend:
        traced_model = optimize_for_mobile(script_module=traced_model,
                                           backend=opt_backend)
    torch.jit.save(traced_model, path)
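# Usage sketch for the two serialization helpers above; paths are illustrative.
# The backend strings follow optimize_for_mobile's accepted values ('CPU',
# 'Vulkan', 'Metal'); the GPU backends require a PyTorch build with that support.
import torch
import torchvision

model = torchvision.models.mobilenet_v2(pretrained=True).eval()
example = torch.rand(1, 3, 224, 224)

script_and_serialize(model, "mobilenet_scripted.pt", opt_backend="CPU")
trace_and_serialize(model, example, "mobilenet_traced.pt", opt_backend="CPU")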
def convert_to_mobile(model_name, path, dest):
    if model_name == 'concat':
        model = ConcatModel(2)
    elif model_name == 'vgg':
        model = Vgg16(2)
    else:
        model = ResNet50(2)
    model = convert(model, path)
    script_model = torch.jit.script(model)
    mobile_model = mobile_optimizer.optimize_for_mobile(script_model)
    torch.jit.save(mobile_model, dest)
def validate_transform_conv1d_to_conv2d(self,
                                        pattern_count_transformed_map,
                                        pattern_count_optimized_map,
                                        data_shape):
    module_instance = self
    scripted_model = torch.jit.script(module_instance)
    scripted_model.eval()
    input_data = torch.normal(1, 20, size=data_shape)
    ref_result = scripted_model(input_data)
    torch._C._jit_pass_transform_conv1d_to_conv2d(scripted_model._c)
    optimized_scripted_model = optimize_for_mobile(scripted_model)

    buffer = io.BytesIO()
    torch.jit.save(scripted_model, buffer)
    buffer.seek(0)
    deserialized_scripted_model = torch.jit.load(buffer)

    for pattern, v in pattern_count_transformed_map.items():
        if v == 0:
            FileCheck().check(pattern).run(deserialized_scripted_model.graph)
        elif v == -1:
            FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
        else:
            FileCheck().check_count(pattern, v, exactly=True).run(
                deserialized_scripted_model.graph)
    transformed_result = deserialized_scripted_model(input_data)
    torch.testing.assert_allclose(ref_result, transformed_result,
                                  rtol=1e-2, atol=1e-3)

    optimized_buffer = io.BytesIO()
    torch.jit.save(optimized_scripted_model, optimized_buffer)
    optimized_buffer.seek(0)
    deserialized_optimized_scripted_model = torch.jit.load(optimized_buffer)

    for pattern, v in pattern_count_optimized_map.items():
        if v == 0:
            FileCheck().check(pattern).run(
                deserialized_optimized_scripted_model.graph)
        elif v == -1:
            FileCheck().check_not(pattern).run(
                deserialized_optimized_scripted_model.graph)
        else:
            FileCheck().check_count(pattern, v, exactly=True).run(
                deserialized_optimized_scripted_model.graph)
    xnnpack_result = deserialized_optimized_scripted_model(input_data)
    torch.testing.assert_allclose(ref_result, xnnpack_result,
                                  rtol=1e-2, atol=1e-3)
def export_torchscript(model, img, file, optimize):
    # TorchScript model export
    prefix = colorstr('TorchScript:')
    try:
        print(f'\n{prefix} starting export with torch {torch.__version__}...')
        f = file.with_suffix('.torchscript.pt')
        ts = torch.jit.trace(model, img, strict=False)
        (optimize_for_mobile(ts) if optimize else ts).save(f)
        print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        return ts
    except Exception as e:
        print(f'{prefix} export failure: {e}')
def save_model_for_mobile(self, width: int, height: int, out_filepath: str, for_os="cpu"): torch_model = self.to("cpu") torch_model.eval() if for_os == "cpu": example = torch.rand(1, 3, height, width).to("cpu") traced_script_module = torch.jit.trace(torch_model, example) traced_script_module.save(out_filepath) return script_model = torch.jit.script(torch_model) if for_os == "android": mobile_optimizer.optimize_for_mobile(script_model, backend="Vulkan") elif for_os == "ios": mobile_optimizer.optimize_for_mobile(script_model, backend="metal") torch.jit.save(script_model, out_filepath)
def getModule(self):
    model = torchvision.models.mobilenet_v2(pretrained=True)
    model.eval()
    example = torch.zeros(1, 3, 224, 224)
    traced_script_module = torch.jit.trace(model, example)
    optimized_module = optimize_for_mobile(traced_script_module)
    augment_model_with_bundled_inputs(
        optimized_module,
        [
            (example, ),
        ],
    )
    optimized_module(example)
    return optimized_module
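# Self-contained variant of getModule() showing the helper methods that
# augment_model_with_bundled_inputs attaches; the output path is illustrative.
import torch
import torchvision
from torch.utils.bundled_inputs import augment_model_with_bundled_inputs
from torch.utils.mobile_optimizer import optimize_for_mobile

model = torchvision.models.mobilenet_v2(pretrained=True).eval()
example = torch.zeros(1, 3, 224, 224)
module = optimize_for_mobile(torch.jit.trace(model, example))
augment_model_with_bundled_inputs(module, [(example,)])

print(module.get_num_bundled_inputs())  # 1
module.run_on_bundled_input(0)          # forward pass on bundled input #0
module._save_for_lite_interpreter("mobilenet_v2_bundled.ptl")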
def test_save_mobile_module_with_debug_info_with_script_nested_call(self):
    class A(torch.nn.Module):
        def __init__(self):
            super(A, self).__init__()

        def forward(self, x):
            return x + 1

    class B(torch.nn.Module):
        def __init__(self):
            super(B, self).__init__()

        def forward(self, x):
            return x + 2

    class C(torch.nn.Module):
        def __init__(self):
            super(C, self).__init__()
            self.A0 = A()
            self.B0 = B()

        def forward(self, x):
            return self.A0(self.B0(x)) + 1

    input = torch.tensor([5])
    scripted_module = torch.jit.script(C(), input)
    optimized_scripted_module = optimize_for_mobile(scripted_module)

    exported_module = scripted_module._save_to_buffer_for_lite_interpreter(
        _save_mobile_debug_info=True)
    optimized_exported_module = optimized_scripted_module._save_to_buffer_for_lite_interpreter(
        _save_mobile_debug_info=True)

    assert b"mobile_debug.pkl" in exported_module
    assert b"module_debug_info" in exported_module
    assert b"top(C).forward" in exported_module
    assert b"top(C).A0(A).forward" in exported_module
    assert b"top(C).B0(B).forward" in exported_module

    assert b"mobile_debug.pkl" in optimized_exported_module
    assert b"module_debug_info" in optimized_exported_module
    assert b"top(C).forward" in optimized_exported_module
    assert b"top(C).A0(A).forward" in optimized_exported_module
    assert b"top(C).B0(B).forward" in optimized_exported_module
def test_quantized_conv_no_asan_failures(self):
    # There were ASAN failures when fold_conv_bn was run on
    # already quantized conv modules. Verifying that this does
    # not happen again.
    if 'qnnpack' not in torch.backends.quantized.supported_engines:
        return

    class Child(nn.Module):
        def __init__(self):
            super(Child, self).__init__()
            self.conv2 = nn.Conv2d(1, 1, 1)

        def forward(self, x):
            x = self.conv2(x)
            return x

    class Parent(nn.Module):
        def __init__(self):
            super(Parent, self).__init__()
            self.quant = torch.ao.quantization.QuantStub()
            self.conv1 = nn.Conv2d(1, 1, 1)
            self.child = Child()
            self.dequant = torch.ao.quantization.DeQuantStub()

        def forward(self, x):
            x = self.quant(x)
            x = self.conv1(x)
            x = self.child(x)
            x = self.dequant(x)
            return x

    with override_quantized_engine('qnnpack'):
        model = Parent()
        model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
        torch.ao.quantization.prepare(model, inplace=True)
        model(torch.randn(4, 1, 4, 4))
        torch.ao.quantization.convert(model, inplace=True)
        model = torch.jit.script(model)
        # this line should not have ASAN failures
        model_optim = optimize_for_mobile(model)
def run(
        weights='./yolov5s.pt',  # weights path
        img_size=(640, 640),  # image (height, width)
        batch_size=1,  # batch size
        device='cpu',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
        include=('torchscript', 'onnx', 'coreml'),  # include formats
        half=False,  # FP16 half-precision export
        inplace=False,  # set YOLOv5 Detect() inplace=True
        train=False,  # model.train() mode
        optimize=False,  # TorchScript: optimize for mobile
        dynamic=False,  # ONNX: dynamic axes
        simplify=False,  # ONNX: simplify model
        opset_version=12,  # ONNX: opset version
):
    t = time.time()
    include = [x.lower() for x in include]
    img_size *= 2 if len(img_size) == 1 else 1  # expand

    # Load PyTorch model
    device = select_device(device)
    assert not (device.type == 'cpu' and half), \
        '--half only compatible with GPU export, i.e. use --device 0'
    model = attempt_load(weights, map_location=device)  # load FP32 model
    labels = model.names

    # Input
    gs = int(max(model.stride))  # grid size (max stride)
    img_size = [check_img_size(x, gs) for x in img_size]  # verify img_size are gs-multiples
    img = torch.zeros(batch_size, 3, *img_size).to(device)  # image size(1,3,320,192) iDetection

    # Update model
    if half:
        img, model = img.half(), model.half()  # to FP16
    model.train() if train else model.eval()  # training mode = no Detect() layer grid construction
    for k, m in model.named_modules():
        m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
        if isinstance(m, Conv):  # assign export-friendly activations
            if isinstance(m.act, nn.Hardswish):
                m.act = Hardswish()
            elif isinstance(m.act, nn.SiLU):
                m.act = SiLU()
        elif isinstance(m, Detect):
            m.inplace = inplace
            m.onnx_dynamic = dynamic
            # m.forward = m.forward_export  # assign forward (optional)

    for _ in range(2):
        y = model(img)  # dry runs
    print(f"\n{colorstr('PyTorch:')} starting from {weights} ({file_size(weights):.1f} MB)")

    # TorchScript export ----------------------------------------------------------------------------------------------
    if 'torchscript' in include or 'coreml' in include:
        prefix = colorstr('TorchScript:')
        try:
            print(f'\n{prefix} starting export with torch {torch.__version__}...')
            f = weights.replace('.pt', '.torchscript.pt')  # filename
            ts = torch.jit.trace(model, img, strict=False)
            (optimize_for_mobile(ts) if optimize else ts).save(f)
            print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        except Exception as e:
            print(f'{prefix} export failure: {e}')

    # ONNX export -----------------------------------------------------------------------------------------------------
    if 'onnx' in include:
        prefix = colorstr('ONNX:')
        try:
            import onnx

            print(f'{prefix} starting export with onnx {onnx.__version__}...')
            f = weights.replace('.pt', '.onnx')  # filename
            torch.onnx.export(
                model, img, f, verbose=False, opset_version=opset_version,
                training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
                do_constant_folding=not train,
                input_names=['images'],
                output_names=['output'],
                dynamic_axes={
                    'images': {0: 'batch', 2: 'height', 3: 'width'},  # shape(1,3,640,640)
                    'output': {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)
                } if dynamic else None)

            # Checks
            model_onnx = onnx.load(f)  # load onnx model
            onnx.checker.check_model(model_onnx)  # check onnx model
            # print(onnx.helper.printable_graph(model_onnx.graph))  # print

            # Simplify
            if simplify:
                try:
                    check_requirements(['onnx-simplifier'])
                    import onnxsim

                    print(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
                    model_onnx, check = onnxsim.simplify(
                        model_onnx,
                        dynamic_input_shape=dynamic,
                        input_shapes={'images': list(img.shape)} if dynamic else None)
                    assert check, 'assert check failed'
                    onnx.save(model_onnx, f)
                except Exception as e:
                    print(f'{prefix} simplifier failure: {e}')
            print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        except Exception as e:
            print(f'{prefix} export failure: {e}')

    # CoreML export ---------------------------------------------------------------------------------------------------
    if 'coreml' in include:
        prefix = colorstr('CoreML:')
        try:
            import coremltools as ct

            print(f'{prefix} starting export with coremltools {ct.__version__}...')
            assert train, 'CoreML exports should be placed in model.train() mode with `python export.py --train`'
            model = ct.convert(ts, inputs=[ct.ImageType('image', shape=img.shape,
                                                        scale=1 / 255.0, bias=[0, 0, 0])])
            f = weights.replace('.pt', '.mlmodel')  # filename
            model.save(f)
            print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        except Exception as e:
            print(f'{prefix} export failure: {e}')

    # Finish
    print(f'\nExport complete ({time.time() - t:.2f}s). '
          f'Visualize with https://github.com/lutzroeder/netron.')
def test_optimize_for_mobile(self):
    batch_size = 2
    input_channels_per_group = 6
    height = 16
    width = 16
    output_channels_per_group = 6
    groups = 4
    kernel_h = kernel_w = 3
    stride_h = stride_w = 1
    pad_h = pad_w = 1
    dilation = 1
    input_channels = input_channels_per_group * groups
    output_channels = output_channels_per_group * groups
    kernels = (kernel_h, kernel_w)
    strides = (stride_h, stride_w)
    paddings = (pad_h, pad_w)
    dilations = (dilation, dilation)
    conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
    conv_bias_shape = (output_channels)

    input_data = torch.rand((batch_size, input_channels, height, width))
    conv_weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
    conv_bias = torch.rand((output_channels))
    result = F.conv2d(input_data, conv_weight, conv_bias, strides, paddings, dilations, groups)

    weight_output_dim = 24
    linear_input_shape = result.shape[1]
    linear_weight_shape = (weight_output_dim, linear_input_shape)

    class MyTestModule(torch.nn.Module):
        def __init__(self):
            super(MyTestModule, self).__init__()
            self.conv_weight = torch.nn.Parameter(torch.rand(conv_weight_shape))
            self.conv_bias = torch.nn.Parameter(torch.rand((conv_bias_shape)))
            self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape))
            self.linear_bias = torch.nn.Parameter(torch.rand((weight_output_dim)))
            self.strides = strides
            self.paddings = paddings
            self.dilations = dilations
            self.groups = groups

        def forward(self, x):
            o = F.conv2d(x, self.conv_weight, self.conv_bias,
                         self.strides, self.paddings, self.dilations, self.groups)
            o = F.relu(o)
            x = o.permute([0, 2, 3, 1])
            o = F.linear(x, self.linear_weight, self.linear_bias)
            o = o + x
            return F.relu(o)

        @torch.jit.export
        def foo(self, x):
            o = F.conv2d(x, self.conv_weight, self.conv_bias,
                         self.strides, self.paddings, self.dilations, self.groups)
            o = F.relu(o)
            x = o.permute([0, 2, 3, 1])
            o = F.linear(x, self.linear_weight, self.linear_bias)
            o = o + x
            return F.relu(o)

    class BNTestModule(torch.nn.Module):
        def __init__(self):
            super(BNTestModule, self).__init__()
            self.conv = torch.nn.Conv2d(1, 20, 5, 1)
            self.bn = torch.nn.BatchNorm2d(num_features=20)
            self.bn.eps = 0.0023

        def forward(self, x):
            x = self.conv(x)
            x = self.bn(x)
            return x

    data_shape = (batch_size, input_channels, height, width)
    input_data = torch.normal(1, 20, size=data_shape)

    scripted_model = torch.jit.script(MyTestModule())
    scripted_model.eval()
    initial_result = scripted_model(input_data)
    initial_foo_result = scripted_model.foo(input_data)

    optimized_scripted_model = optimize_for_mobile(scripted_model,
                                                   preserved_methods=['foo'])
    optimized_result = optimized_scripted_model(input_data)
    optimized_foo_result = optimized_scripted_model.foo(input_data)

    FileCheck().check_not("Tensor = aten::conv2d") \
               .check_not("Tensor = prim::CallFunction") \
               .check_not("prepacked::conv2d_clamp_prepack") \
               .check_count("prepacked::conv2d_clamp_run", 1, exactly=True) \
               .check_not("prepacked::linear_clamp_prepack") \
               .check_count("prepacked::linear_clamp_run", 1, exactly=True) \
               .check_not("aten::add(") \
               .check_not("aten::relu(") \
               .check_count("aten::_add_relu(", 1, exactly=True) \
               .run(optimized_scripted_model.graph)
    torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)

    FileCheck().check_not("Tensor = aten::conv2d") \
               .check_not("Tensor = prim::CallFunction") \
               .check_not("prepacked::conv2d_clamp_prepack") \
               .check_count("prepacked::conv2d_clamp_run", 1, exactly=True) \
               .check_not("prepacked::linear_clamp_prepack") \
               .check_count("prepacked::linear_clamp_run", 1, exactly=True) \
               .check_not("aten::add(") \
               .check_not("aten::relu(") \
               .check_count("aten::_add_relu(", 1, exactly=True) \
               .run(optimized_scripted_model.foo.graph)
    torch.testing.assert_allclose(initial_foo_result, optimized_foo_result, rtol=1e-2, atol=1e-3)

    optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
    optimized_scripted_model_no_prepack = optimize_for_mobile(
        scripted_model, optimization_blocklist_no_prepack)
    optimized_result_no_prepack = optimized_scripted_model_no_prepack(input_data)
    FileCheck().check_count("Tensor = aten::conv2d", 1, exactly=True) \
               .check_not("prepacked::linear_clamp_run") \
               .check_not("prepacked::conv2d_clamp_run") \
               .run(optimized_scripted_model_no_prepack.graph)
    torch.testing.assert_allclose(initial_result, optimized_result_no_prepack,
                                  rtol=1e-2, atol=1e-3)

    bn_test_module = BNTestModule()
    bn_scripted_module = torch.jit.script(bn_test_module)
    bn_scripted_module.eval()
    self.assertEqual(len(torch.jit.export_opnames(bn_scripted_module)), 14)
    FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
               .run(str(get_forward(bn_scripted_module._c).graph))

    optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
    bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module,
                                                  optimization_blocklist_no_prepack)
    self.assertEqual(len(torch.jit.export_opnames(bn_fold_scripted_module)), 1)
    bn_input = torch.rand(1, 1, 6, 6)
    torch.testing.assert_allclose(bn_scripted_module(bn_input),
                                  bn_fold_scripted_module(bn_input),
                                  rtol=1e-2, atol=1e-3)

    optimization_blocklist_no_fold_bn = {MobileOptimizerType.CONV_BN_FUSION}
    no_bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module,
                                                     optimization_blocklist_no_fold_bn)
    FileCheck().check_count("aten::batch_norm", 1, exactly=True) \
               .run(str(get_forward_graph(no_bn_fold_scripted_module._c)))
    bn_input = torch.rand(1, 1, 6, 6)
    torch.testing.assert_allclose(bn_scripted_module(bn_input),
                                  no_bn_fold_scripted_module(bn_input),
                                  rtol=1e-2, atol=1e-3)

    class MyMobileOptimizedTagTest(torch.nn.Module):
        def __init__(self):
            super(MyMobileOptimizedTagTest, self).__init__()
            self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape))
            self.linear_bias = torch.nn.Parameter(torch.rand((weight_output_dim)))

        def forward(self, x):
            o = F.linear(x, self.linear_weight, self.linear_bias)
            return F.relu(o)

    mobile_optimized_tag_module = MyMobileOptimizedTagTest()
    m = torch.jit.script(mobile_optimized_tag_module)
    m.eval()
    opt_m = optimize_for_mobile(m)
    tag = getattr(opt_m, "mobile_optimized", None)
    self.assertTrue(tag)

    class MyPreserveMethodsTest(torch.nn.Module):
        def __init__(self):
            super(MyPreserveMethodsTest, self).__init__()
            self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape))
            self.linear_bias = torch.nn.Parameter(torch.rand((weight_output_dim)))

        def forward(self, x):
            o = F.linear(x, self.linear_weight, self.linear_bias)
            return F.relu(o)

        @torch.jit.export
        def preserveThis(self):
            pass

    preserve_method_module = MyPreserveMethodsTest()
    m = torch.jit.script(preserve_method_module)
    m.eval()
    opt_m = optimize_for_mobile(m)
    no_preserveThis = getattr(opt_m, "preserveThis", None)
    self.assertEqual(no_preserveThis, None)
    opt_m = optimize_for_mobile(m, preserved_methods=["preserveThis"])
    preserveThis = getattr(opt_m, "preserveThis", None)
    self.assertNotEqual(preserveThis, None)

    class OptimizeNoForwardTest(torch.nn.Module):
        def __init__(self):
            super(OptimizeNoForwardTest, self).__init__()
            self.l = nn.Linear(10, 100)
            self.l2 = nn.Linear(100, 1)
            self.d = nn.Dropout(p=0.2)

        @torch.jit.export
        def foo(self, x):
            x = self.d(F.relu(self.l(x)))
            x = self.l2(x)
            x = x + torch.ones(1, 100)
            return F.relu(x)

    input_data = torch.ones(1, 10)
    m = torch.jit.script(OptimizeNoForwardTest())
    m.eval()
    initial_result = m.foo(input_data)

    optimized_scripted_model = optimize_for_mobile(m, preserved_methods=['foo'])
    optimized_result = optimized_scripted_model.foo(input_data)

    FileCheck().check_not("dropout.__") \
               .check_count("aten::_add_relu(", 1, exactly=True) \
               .run(optimized_scripted_model.foo.graph)
    torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)

    class BNTestNoForwardModule(torch.nn.Module):
        def __init__(self):
            super(BNTestNoForwardModule, self).__init__()
            self.conv = torch.nn.Conv2d(1, 20, 5, 1)
            self.bn = torch.nn.BatchNorm2d(num_features=20)
            self.bn.eps = 0.0023

        @torch.jit.export
        def foo(self, x):
            x = self.conv(x)
            x = self.bn(x)
            return x

    bn_test_no_forward_module = BNTestNoForwardModule()
    bn_no_forward_scripted_module = torch.jit.script(bn_test_no_forward_module)
    bn_no_forward_scripted_module.eval()
    self.assertEqual(len(torch.jit.export_opnames(bn_no_forward_scripted_module)), 14)
    FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
               .run(bn_no_forward_scripted_module.foo.graph)

    bn_fold_no_forward_scripted_module = optimize_for_mobile(
        bn_no_forward_scripted_module, preserved_methods=['foo'])
    self.assertEqual(len(torch.jit.export_opnames(bn_fold_no_forward_scripted_module)), 1)
    bn_input = torch.rand(1, 1, 6, 6)
    torch.testing.assert_allclose(
        bn_no_forward_scripted_module.foo(bn_input),
        bn_fold_no_forward_scripted_module.foo(bn_input),
        rtol=1e-2, atol=1e-3)
import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile

model = torchvision.models.mobilenet_v2(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
optimized_scripted_module = optimize_for_mobile(traced_script_module)

# Export the full JIT model (used by the regular TorchScript runtime)
torch.jit.save(optimized_scripted_module, '../models/model.pt')
# Export for the lite interpreter (used by PyTorch Mobile)
exported_optimized_scripted_module = optimized_scripted_module._save_for_lite_interpreter(
    "../models/model_lite.ptl")
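# Loading the two artifacts back (paths as saved above);
# _load_for_lite_interpreter runs the .ptl in-process for a sanity check.
import torch
from torch.jit.mobile import _load_for_lite_interpreter

example = torch.rand(1, 3, 224, 224)

model = torch.jit.load('../models/model.pt')  # full JIT artifact
print(model(example).shape)                   # torch.Size([1, 1000])

lite_model = _load_for_lite_interpreter('../models/model_lite.ptl')
print(lite_model(example).shape)              # torch.Size([1, 1000])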
def trace_and_save_torchscript(
    model: nn.Module,
    inputs: Tuple[torch.Tensor],
    output_path: str,
    torchscript_filename: str = "model.jit",
    mobile_optimization: Optional[MobileOptimizationConfig] = None,
    _extra_files: Optional[Dict[str, bytes]] = None,
):
    logger.info("Tracing and saving TorchScript to {} ...".format(output_path))
    PathManager.mkdirs(output_path)
    if _extra_files is None:
        _extra_files = {}

    with torch.no_grad():
        script_model = torch.jit.trace(model, inputs)

    with make_temp_directory("trace_and_save_torchscript") as tmp_dir:

        @contextlib.contextmanager
        def _synced_local_file(rel_path):
            remote_file = os.path.join(output_path, rel_path)
            local_file = os.path.join(tmp_dir, rel_path)
            yield local_file
            PathManager.copy_from_local(local_file, remote_file, overwrite=True)

        with _synced_local_file(torchscript_filename) as model_file:
            torch.jit.save(script_model, model_file, _extra_files=_extra_files)

        with _synced_local_file("data.pth") as data_file:
            torch.save(inputs, data_file)

        if mobile_optimization is not None:
            logger.info("Applying optimize_for_mobile ...")
            liteopt_model = optimize_for_mobile(
                script_model,
                optimization_blocklist=mobile_optimization.optimization_blocklist,
                preserved_methods=mobile_optimization.preserved_methods,
                backend=mobile_optimization.backend,
            )
            torchscript_filename = mobile_optimization.torchscript_filename
            with _synced_local_file(torchscript_filename) as lite_path:
                liteopt_model._save_for_lite_interpreter(lite_path,
                                                         _extra_files=_extra_files)
            # liteopt_model(*inputs)  # sanity check
            op_names = torch.jit.export_opnames(liteopt_model)
            logger.info("Operator names from lite interpreter:\n{}".format(
                "\n".join(op_names)))

            logger.info("Applying augment_model_with_bundled_inputs ...")
            # make all tensors zero-like to save storage
            iters = recursive_iterate(inputs)
            for x in iters:
                if isinstance(x, torch.Tensor):
                    iters.send(torch.zeros_like(x).contiguous())
            inputs = iters.value
            augment_model_with_bundled_inputs(liteopt_model, [inputs])
            liteopt_model(*liteopt_model.get_all_bundled_inputs()[0])  # sanity check

            name, ext = os.path.splitext(torchscript_filename)
            with _synced_local_file(name + "_bundled" + ext) as lite_path:
                liteopt_model._save_for_lite_interpreter(lite_path)

    return torchscript_filename