Example #1
def export_torchscript(model,
                       im,
                       file,
                       optimize,
                       prefix=colorstr('TorchScript:')):
    # YOLOv5 TorchScript model export
    try:
        LOGGER.info(
            f'\n{prefix} starting export with torch {torch.__version__}...')
        f = file.with_suffix('.torchscript')

        ts = torch.jit.trace(model, im, strict=False)
        d = {
            "shape": im.shape,
            "stride": int(max(model.stride)),
            "names": model.names
        }
        extra_files = {'config.txt': json.dumps(d)}  # torch._C.ExtraFilesMap()
        if optimize:  # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
            optimize_for_mobile(ts)._save_for_lite_interpreter(
                str(f), _extra_files=extra_files)
        else:
            ts.save(str(f), _extra_files=extra_files)

        LOGGER.info(
            f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        return f
    except Exception as e:
        LOGGER.info(f'{prefix} export failure: {e}')
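A minimal follow-up sketch (not part of the export code above; the file name is illustrative): the metadata written through _extra_files can be read back by passing a dict to torch.jit.load, which fills in the values in place.

import json
import torch

extra_files = {'config.txt': ''}  # value is populated in place by torch.jit.load
loaded = torch.jit.load('yolov5s.torchscript', _extra_files=extra_files)
meta = json.loads(extra_files['config.txt'])  # {'shape': ..., 'stride': ..., 'names': ...}
print(meta['stride'], meta['names'])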
Example #2
    def test_preserve_bundled_inputs_methods(self):
        class MyBundledInputModule(torch.nn.Module):
            def __init__(self):
                super(MyBundledInputModule, self).__init__()

            def forward(self, inputs):
                return inputs

        class MyIncompleteBundledInputModule(torch.nn.Module):
            def __init__(self):
                super(MyIncompleteBundledInputModule, self).__init__()

            def forward(self, inputs):
                return inputs

            @torch.jit.export
            def get_all_bundled_inputs(self):
                pass

        bi_module = torch.jit.script(MyBundledInputModule())
        module_optim_bi_not_preserved = optimize_for_mobile(bi_module)

        # Expected to be False since no bundled inputs methods were added
        self.assertFalse(
            hasattr(module_optim_bi_not_preserved, 'get_all_bundled_inputs') or
            hasattr(module_optim_bi_not_preserved, 'get_num_bundled_inputs'))

        # Add bundled inputs methods to the module
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            bi_module, [(torch.tensor([1]), )], [])
        # Now they should be preserved
        module_optim_bi_preserved = optimize_for_mobile(bi_module)

        # All of the bundled inputs methods were preserved
        self.assertTrue(
            hasattr(module_optim_bi_preserved, 'get_all_bundled_inputs')
            and hasattr(module_optim_bi_preserved, 'get_num_bundled_inputs'))

        bundled_input = module_optim_bi_preserved.get_all_bundled_inputs()[0]
        module_optim_bi_preserved(*bundled_input)

        # If not all 3 bundled inputs methods are present in the module,
        # we will not try to preserve them unless specified by the user.
        incomplete_bi_module = torch.jit.script(
            MyIncompleteBundledInputModule())
        incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module)
        self.assertFalse(
            hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs'))

        # Specifically preserve get_all_bundled_inputs even if it's the only one
        # bundled inputs method available.
        incomplete_bi_module_optim = optimize_for_mobile(
            incomplete_bi_module, preserved_methods=['get_all_bundled_inputs'])
        self.assertTrue(
            hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs'))
Example #3
def quantize(model, data_loader, config="fbgemm", name="lanes"):
    # Configuration
    prep_config_dict = {"non_traceable_module_name": ["base", "deconv"]}
    qconfig = get_default_qconfig(config)
    qconfig_dict = {"": qconfig}
    model.load()
    model.eval()
    # Prepare Model
    model_prepared = prepare_fx(model,
                                qconfig_dict,
                                prepare_custom_config_dict=prep_config_dict)

    calibrate(model_prepared, data_loader)
    model_int_8 = convert_fx(model_prepared)
    # Model Description
    params = sum([np.prod(p.size()) for p in model.parameters()])
    print("ORIGINAL")
    print("Number of Parameters: {:.1f}M".format(params / 1e6))
    print(f"Number of Parameters: {params}M")
    params = sum([np.prod(p.size()) for p in model_int_8.parameters()])
    print("QUANTIZED")
    print("Number of Parameters: {:.6f}M".format(params / 1e6))
    print(f"Number of Parameters: {params}M")

    print_size_of_model(model_int_8)

    mobile_model = torch.jit.script(model_int_8)
    torchscript_mobile = optimize_for_mobile(mobile_model)
    torch.jit.save(torchscript_mobile, MODEL_MAIN_DIR + name + "_mobile.pt")

    torch.jit.save(torch.jit.script(model_int_8),
                   MODEL_MAIN_DIR + "quantized_" + name + "Net.pt")

    return model_int_8
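The calibrate() helper used above is not shown in the snippet; here is a minimal sketch of what it presumably does, assuming the data loader yields inputs the model accepts directly.

import torch

def calibrate(model_prepared, data_loader):
    # Feed calibration batches through the prepared model so the observers
    # inserted by prepare_fx can record activation ranges.
    model_prepared.eval()
    with torch.no_grad():
        for batch in data_loader:
            model_prepared(batch)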
Example #4
def _main():
    args = _parse_args()
    _init_logging(args.debug)
    loader = Loader()
    model = _get_model(args.model_file, args.dict_dir).eval()
    encoder = Encoder(model)
    decoder = _get_decoder()
    _LG.info(encoder)

    if args.quantize:
        _LG.info('Quantizing the model')
        model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
        encoder = tq.quantize_dynamic(encoder,
                                      qconfig_spec={torch.nn.Linear},
                                      dtype=torch.qint8)
        _LG.info(encoder)

    # test
    if args.test_file:
        _LG.info('Testing with %s', args.test_file)
        waveform = loader(args.test_file)
        emission = encoder(waveform)
        transcript = decoder(emission)
        _LG.info(transcript)

    torch.jit.script(loader).save(os.path.join(args.output_path, 'loader.zip'))
    torch.jit.script(decoder).save(
        os.path.join(args.output_path, 'decoder.zip'))
    scripted = torch.jit.script(encoder)
    if args.optimize_for_mobile:
        scripted = optimize_for_mobile(scripted)
    scripted.save(os.path.join(args.output_path, 'encoder.zip'))
Example #5
 def extract_torchscript(self):
     """Extract torchscript module from image classification model"""
     super().print_message("torchscript")
     scripted_module = torch.jit.script(self.model)
     optimized_module = mobile_optimizer.optimize_for_mobile(
         scripted_module)
     optimized_module.save("{}.pt".format(self.file_name))
Example #6
def trace_and_save_torchscript(
    model: torch.nn.Module,
    inputs: typing.Tuple[typing.Any, ...],
    output_path: str,
    use_get_traceable=False,
    trace_type="trace",
    opt_for_mobile=False,
):
    logger.info("Tracing and saving TorchScript to {} ...".format(output_path))

    with torch.no_grad():
        if use_get_traceable:
            model = ju.get_traceable_model(model)
        if trace_type == "trace":
            script_model = torch.jit.trace(model, inputs, strict=False)
        else:
            script_model = torch.jit.script(model)

    if opt_for_mobile:
        logger.info("Running optimize_for_mobile...")
        script_model = optimize_for_mobile(script_model)

    os.makedirs(output_path, exist_ok=True)

    model_file = os.path.join(output_path, "model.jit")
    script_model.save(model_file)

    data_file = os.path.join(output_path, "data.pth")
    torch.save(inputs, data_file)

    return model_file
Example #7
def save_mobile_model():
    model.eval()
    example = torch.rand(1, 3, 480, 640).to(device)
    traced_script_module = torch.jit.trace(model, example)
    optimized_traced_model = optimize_for_mobile(traced_script_module)
    optimized_traced_model.save("mobile_model.pt")
    print('===> Mobile model saved!')
Example #8
 def extract_torchscript(self):
     """Extract torchscript module from sentiment analysis model"""
     super().print_message("torchscript")
     scripted_module = torch.jit.trace(self.model, self.model_example)
     optimized_module = mobile_optimizer.optimize_for_mobile(
         scripted_module)
     optimized_module.save("{}.pt".format(self.file_name))
Example #9
def export_torchscript(model,
                       im,
                       file,
                       optimize,
                       prefix=colorstr('TorchScript:')):
    #  TorchScript model export
    try:
        LOGGER.info(
            f'\n{prefix} starting export with torch {torch.__version__}...')
        f = file.with_suffix('.torchscript.pt')

        ts = torch.jit.trace(model, im, strict=False)
        d = {
            "shape": im.shape,
            "stride": int(max(model.stride)),
            "names": model.names
        }
        extra_files = {'config.txt': json.dumps(d)}  # torch._C.ExtraFilesMap()
        (optimize_for_mobile(ts) if optimize else ts).save(
            f, _extra_files=extra_files)

        LOGGER.info(
            f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
    except Exception as e:
        LOGGER.info(f'{prefix} export failure: {e}')
Example #10
def test_export_torchvision_format():
    cfg_name = 'faster_rcnn_fbnetv3a_dsmask_C4.yaml'
    pytorch_model = model_zoo.get(cfg_name, trained=True)

    from typing import List, Dict

    class Wrapper(torch.nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model
            coco_idx_list = [
                1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19,
                20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38,
                39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
                56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75,
                76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90, 91
            ]

            self.coco_idx = torch.tensor(coco_idx_list)

        def forward(self, inputs: List[torch.Tensor]):
            x = inputs[0].unsqueeze(0) * 255
            scale = 320.0 / min(x.shape[-2], x.shape[-1])
            x = torch.nn.functional.interpolate(x,
                                                scale_factor=scale,
                                                mode="bilinear",
                                                align_corners=True,
                                                recompute_scale_factor=True)
            out = self.model(x[0])
            res: Dict[str, torch.Tensor] = {}
            res["boxes"] = out[0] / scale
            res["labels"] = torch.index_select(self.coco_idx, 0, out[1])
            res["scores"] = out[2]
            return inputs, [res]

    size_divisibility = max(pytorch_model.backbone.size_divisibility, 10)
    h, w = size_divisibility, size_divisibility * 2

    runner = create_runner("d2go.runner.GeneralizedRCNNRunner")
    cfg = model_zoo.get_config(cfg_name)
    datasets = list(cfg.DATASETS.TRAIN)

    data_loader = runner.build_detection_test_loader(cfg, datasets)

    predictor_path = convert_and_export_predictor(
        cfg,
        copy.deepcopy(pytorch_model),
        "torchscript_int8@tracing",
        './',
        data_loader,
    )

    orig_model = torch.jit.load(os.path.join(predictor_path, "model.jit"))
    wrapped_model = Wrapper(orig_model)
    # optionally do a forward
    wrapped_model([torch.rand(3, 600, 600)])
    scripted_model = torch.jit.script(wrapped_model)
    optimized_model = optimize_for_mobile(scripted_model)
    optimized_model.save("D2Go/d2go_optimized.pt")
Example #11
        def _benchmark_conv3d_pw_bn_relu_forward(**kwargs) -> Callable:
            assert kwargs["mode"] in ("original", "deployable"), (
                "kwargs['mode'] must be either 'original' or 'deployable',"
                "but got {}.".format(kwargs["mode"]))
            input_tensor = torch.randn((kwargs["input_blob_size"]))
            conv_block = Conv3dPwBnAct(
                kwargs["in_channels"],
                kwargs["out_channels"],
                use_bn=False,  # assume BN has already been fused for forward
            )

            if kwargs["mode"] == "deployable":
                conv_block.convert(kwargs["input_blob_size"])
            conv_block.eval()
            if kwargs["quantize"] is True:
                if kwargs["mode"] == "original":  # manually fuse conv and relu
                    conv_block.kernel = torch.quantization.fuse_modules(
                        conv_block.kernel, ["conv", "act.act"])
                conv_block = nn.Sequential(
                    torch.quantization.QuantStub(),
                    conv_block,
                    torch.quantization.DeQuantStub(),
                )

                conv_block.qconfig = torch.quantization.get_default_qconfig(
                    "qnnpack")
                conv_block = torch.quantization.prepare(conv_block)
                try:
                    conv_block = torch.quantization.convert(conv_block)
                except Exception as e:
                    logging.info(
                        "benchmark_conv3d_pw_bn_relu: "
                        "catch exception '{}' with kwargs of {}".format(
                            e, kwargs))

                    def func_to_benchmark_dummy() -> None:
                        return

                    return func_to_benchmark_dummy
            traced_model = torch.jit.trace(conv_block,
                                           input_tensor,
                                           strict=False)
            if kwargs["quantize"] is False:
                traced_model = optimize_for_mobile(traced_model)

            logging.info(f"model arch: {traced_model}")

            def func_to_benchmark() -> None:
                try:
                    _ = traced_model(input_tensor)
                except Exception as e:
                    logging.info(
                        "benchmark_conv3d_pw_bn_relu: "
                        "catch exception '{}' with kwargs of {}".format(
                            e, kwargs))

                return

            return func_to_benchmark
Example #12
def script_and_serialize(model: nn.Module,
                         path: str,
                         opt_backend: Optional[str] = None):
    scripted_model = torch.jit.script(model)
    if opt_backend:
        scripted_model = optimize_for_mobile(script_module=scripted_model,
                                             backend=opt_backend)
    torch.jit.save(scripted_model, path)
Example #13
        def _benchmark_x3d_bottleneck_forward(**kwargs) -> Callable:
            assert kwargs["mode"] in ("original", "deployable"), (
                "kwargs['mode'] must be either 'original' or 'deployable',"
                "but got {}.".format(kwargs["mode"]))
            input_tensor = torch.randn((kwargs["input_blob_size"]))
            conv_block = X3dBottleneckBlock(
                kwargs["in_channels"],
                kwargs["mid_channels"],
                kwargs["out_channels"],
                use_bn=(False, False,
                        False),  # Assume BN has been fused for forward
            )

            if kwargs["mode"] == "deployable":
                conv_block.convert(kwargs["input_blob_size"])
            conv_block.eval()
            if kwargs["quantize"] is True:
                conv_block = nn.Sequential(
                    torch.quantization.QuantStub(),
                    conv_block,
                    torch.quantization.DeQuantStub(),
                )

                conv_block.qconfig = torch.quantization.get_default_qconfig(
                    "qnnpack")
                conv_block = torch.quantization.prepare(conv_block)
                try:
                    conv_block = torch.quantization.convert(conv_block)
                except Exception as e:
                    logging.info(
                        "benchmark_x3d_bottleneck_forward: "
                        "catch exception '{}' with kwargs of {}".format(
                            e, kwargs))

                    def func_to_benchmark_dummy() -> None:
                        return

                    return func_to_benchmark_dummy

            traced_model = torch.jit.trace(conv_block,
                                           input_tensor,
                                           strict=False)
            if kwargs["quantize"] is False:
                traced_model = optimize_for_mobile(traced_model)

            logging.info(f"model arch: {traced_model}")

            def func_to_benchmark() -> None:
                try:
                    _ = traced_model(input_tensor)
                except Exception as e:
                    logging.info(
                        "benchmark_x3d_bottleneck_forward: "
                        "catch exception '{}' with kwargs of {}".format(
                            e, kwargs))
                return

            return func_to_benchmark
Example #14
def trace_and_save_torchscript(
    model: nn.Module,
    inputs: Tuple[torch.Tensor],
    output_path: str,
    mobile_optimization: Optional[MobileOptimizationConfig] = None,
    _extra_files: Optional[Dict[str, bytes]] = None,
):
    logger.info("Tracing and saving TorchScript to {} ...".format(output_path))
    PathManager.mkdirs(output_path)
    if _extra_files is None:
        _extra_files = {}

    # TODO: patch_builtin_len depends on D2; we should either copy the function or
    # dynamically register D2's version.
    from detectron2.export.torchscript_patch import patch_builtin_len

    with torch.no_grad(), patch_builtin_len():
        script_model = torch.jit.trace(model, inputs)

    with make_temp_directory("trace_and_save_torchscript") as tmp_dir:

        @contextlib.contextmanager
        def _synced_local_file(rel_path):
            remote_file = os.path.join(output_path, rel_path)
            local_file = os.path.join(tmp_dir, rel_path)
            yield local_file
            PathManager.copy_from_local(local_file,
                                        remote_file,
                                        overwrite=True)

        with _synced_local_file("model.jit") as model_file:
            torch.jit.save(script_model, model_file, _extra_files=_extra_files)

        with _synced_local_file("data.pth") as data_file:
            torch.save(inputs, data_file)

        if mobile_optimization is not None:
            logger.info("Applying optimize_for_mobile ...")
            liteopt_model = optimize_for_mobile(
                script_model,
                optimization_blocklist=mobile_optimization.
                optimization_blocklist,
                preserved_methods=mobile_optimization.preserved_methods,
                backend=mobile_optimization.backend,
            )
            with _synced_local_file("mobile_optimized.ptl") as lite_path:
                liteopt_model._save_for_lite_interpreter(lite_path)
            # liteopt_model(*inputs)  # sanity check
            op_names = torch.jit.export_opnames(liteopt_model)
            logger.info("Operator names from lite interpreter:\n{}".format(
                "\n".join(op_names)))

            logger.info("Applying augment_model_with_bundled_inputs ...")
            augment_model_with_bundled_inputs(liteopt_model, [inputs])
            liteopt_model.run_on_bundled_input(0)  # sanity check
            with _synced_local_file(
                    "mobile_optimized_bundled.ptl") as lite_path:
                liteopt_model._save_for_lite_interpreter(lite_path)
Example #15
 def _quant_script_and_optimize(model):
     model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
     model.fuse_model()
     torch.quantization.prepare(model, inplace=True)
     model(torch.randn(4, 1, 4, 4))
     torch.quantization.convert(model, inplace=True)
     model = torch.jit.script(model)
     model_optim = optimize_for_mobile(model)
     return model, model_optim
Example #16
    def test_all_backport_functions(self):
        # Backport from the latest bytecode version to the minimum support version
        # Load, run the backport model, and check version
        class TestModule(torch.nn.Module):
            def __init__(self, v):
                super().__init__()
                self.x = v

            def forward(self, y: int):
                increment = torch.ones([2, 4], dtype=torch.float64)
                return self.x + y + increment

        module_input = 1
        expected_mobile_module_result = 3 * torch.ones([2, 4],
                                                       dtype=torch.float64)

        # temporary input model file and output model file will be exported in the temporary folder
        with tempfile.TemporaryDirectory() as tmpdirname:
            tmp_input_model_path = Path(tmpdirname, "tmp_script_module.ptl")
            script_module = torch.jit.script(TestModule(1))
            optimized_scripted_module = optimize_for_mobile(script_module)
            exported_optimized_scripted_module = optimized_scripted_module._save_for_lite_interpreter(
                str(tmp_input_model_path))

            current_from_version = _get_model_bytecode_version(
                tmp_input_model_path)
            current_to_version = current_from_version - 1
            tmp_output_model_path = Path(tmpdirname,
                                         "tmp_script_module_backport.ptl")

            while current_to_version >= MINIMUM_TO_VERSION:
                # Backport the latest model to `to_version` to a tmp file "tmp_script_module_backport"
                backport_success = _backport_for_mobile(
                    tmp_input_model_path, tmp_output_model_path,
                    current_to_version)
                assert (backport_success)

                backport_version = _get_model_bytecode_version(
                    tmp_output_model_path)
                assert (backport_version == current_to_version)

                # Load model and run forward method
                mobile_module = _load_for_lite_interpreter(
                    str(tmp_input_model_path))
                mobile_module_result = mobile_module(module_input)
                torch.testing.assert_allclose(mobile_module_result,
                                              expected_mobile_module_result)
                current_to_version -= 1

            # Check backport failure case
            backport_success = _backport_for_mobile(tmp_input_model_path,
                                                    tmp_output_model_path,
                                                    MINIMUM_TO_VERSION - 1)
            assert (not backport_success)
            # need to clean the folder before it closes, otherwise will run into git not clean error
            shutil.rmtree(tmpdirname)
Example #17
    def test_mobilenet_optimize_for_mobile(self):
        m = torchvision.models.mobilenet_v3_small()
        m = torch.jit.script(m)
        m = optimize_for_mobile(m)

        # run forward 3 times until segfault, see https://github.com/pytorch/pytorch/issues/52463
        x = torch.zeros(1, 3, 56, 56)
        self.assertEqual(m(x).numel(), 1000)
        self.assertEqual(m(x).numel(), 1000)
        self.assertEqual(m(x).numel(), 1000)
Example #18
 def initialize_model(self):
     if self._torchscript_model is None:
         self._model = self._lazy_model()
         Path("jit").mkdir(exist_ok=True, parents=True)
         self._torchscript_model = torch.jit.trace(self._model,
                                                   self.data.eval,
                                                   check_trace=False)
         self._torchscript_model = mobile_optimizer.optimize_for_mobile(
             self._torchscript_model)
         self._torchscript_model.save(f"jit/{self.name}.pt")
Example #19
def trace_and_serialize(model: nn.Module,
                        example: torch.Tensor,
                        path: str,
                        opt_backend: Optional[str] = None):
    with torch.no_grad():
        traced_model = torch.jit.trace(model, example_inputs=example)
    if opt_backend:
        traced_model = optimize_for_mobile(script_module=traced_model,
                                           backend=opt_backend)
    torch.jit.save(traced_model, path)
Example #20
def convert_to_mobile(model_name, path, dest):
    if model_name == 'concat':
        model = ConcatModel(2)
    elif model_name == 'vgg':
        model = Vgg16(2)
    else:
        model = ResNet50(2)
    model = convert(model, path)
    script_model = torch.jit.script(model)
    mobile_model = mobile_optimizer.optimize_for_mobile(script_model)
    torch.jit.save(mobile_model, dest)
Example #21
    def validate_transform_conv1d_to_conv2d(self,
                                            pattern_count_transformed_map,
                                            pattern_count_optimized_map,
                                            data_shape):
        module_instance = self
        scripted_model = torch.jit.script(module_instance)
        scripted_model.eval()
        input_data = torch.normal(1, 20, size=data_shape)
        ref_result = scripted_model(input_data)
        torch._C._jit_pass_transform_conv1d_to_conv2d(scripted_model._c)
        optimized_scripted_model = optimize_for_mobile(scripted_model)

        buffer = io.BytesIO()
        torch.jit.save(scripted_model, buffer)
        buffer.seek(0)
        deserialized_scripted_model = torch.jit.load(buffer)

        for pattern, v in pattern_count_transformed_map.items():
            if (v == 0):
                FileCheck().check(pattern).run(
                    deserialized_scripted_model.graph)
            elif (v == -1):
                FileCheck().check_not(pattern).run(
                    deserialized_scripted_model.graph)
            else:
                FileCheck().check_count(pattern, v, exactly=True).run(
                    deserialized_scripted_model.graph)
        transformed_result = deserialized_scripted_model(input_data)
        torch.testing.assert_allclose(ref_result,
                                      transformed_result,
                                      rtol=1e-2,
                                      atol=1e-3)

        optimized_buffer = io.BytesIO()
        torch.jit.save(optimized_scripted_model, optimized_buffer)
        optimized_buffer.seek(0)
        deserialized_optimized_scripted_model = torch.jit.load(
            optimized_buffer)

        for pattern, v in pattern_count_optimized_map.items():
            if (v == 0):
                FileCheck().check(pattern).run(
                    deserialized_optimized_scripted_model.graph)
            elif (v == -1):
                FileCheck().check_not(pattern).run(
                    deserialized_optimized_scripted_model.graph)
            else:
                FileCheck().check_count(pattern, v, exactly=True).run(
                    deserialized_optimized_scripted_model.graph)
        xnnpack_result = deserialized_optimized_scripted_model(input_data)
        torch.testing.assert_allclose(ref_result,
                                      xnnpack_result,
                                      rtol=1e-2,
                                      atol=1e-3)
Example #22
def export_torchscript(model, img, file, optimize):
    # TorchScript model export
    prefix = colorstr('TorchScript:')
    try:
        print(f'\n{prefix} starting export with torch {torch.__version__}...')
        f = file.with_suffix('.torchscript.pt')
        ts = torch.jit.trace(model, img, strict=False)
        (optimize_for_mobile(ts) if optimize else ts).save(f)
        print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        return ts
    except Exception as e:
        print(f'{prefix} export failure: {e}')
Example #23
    def save_model_for_mobile(self,
                              width: int,
                              height: int,
                              out_filepath: str,
                              for_os="cpu"):
        torch_model = self.to("cpu")
        torch_model.eval()

        if for_os == "cpu":
            example = torch.rand(1, 3, height, width).to("cpu")
            traced_script_module = torch.jit.trace(torch_model, example)
            traced_script_module.save(out_filepath)
            return

        script_model = torch.jit.script(torch_model)
        if for_os == "android":
            mobile_optimizer.optimize_for_mobile(script_model,
                                                 backend="Vulkan")
        elif for_os == "ios":
            mobile_optimizer.optimize_for_mobile(script_model, backend="metal")
        torch.jit.save(script_model, out_filepath)
Example #24
 def getModule(self):
     model = torchvision.models.mobilenet_v2(pretrained=True)
     model.eval()
     example = torch.zeros(1, 3, 224, 224)
     traced_script_module = torch.jit.trace(model, example)
     optimized_module = optimize_for_mobile(traced_script_module)
     augment_model_with_bundled_inputs(
         optimized_module,
         [
             (example, ),
         ],
     )
     optimized_module(example)
     return optimized_module
Example #25
    def test_save_mobile_module_with_debug_info_with_script_nested_call(self):
        class A(torch.nn.Module):
            def __init__(self):
                super(A, self).__init__()

            def forward(self, x):
                return x + 1

        class B(torch.nn.Module):
            def __init__(self):
                super(B, self).__init__()

            def forward(self, x):
                return x + 2

        class C(torch.nn.Module):
            def __init__(self):
                super(C, self).__init__()
                self.A0 = A()
                self.B0 = B()

            def forward(self, x):
                return self.A0(self.B0(x)) + 1

        input = torch.tensor([5])
        scripted_module = torch.jit.script(C(), input)

        optimized_scripted_module = optimize_for_mobile(scripted_module)

        exported_module = scripted_module._save_to_buffer_for_lite_interpreter(
            _save_mobile_debug_info=True)
        optimized_exported_module = optimized_scripted_module._save_to_buffer_for_lite_interpreter(
            _save_mobile_debug_info=True)
        assert (b"mobile_debug.pkl" in exported_module)
        assert (b"module_debug_info" in exported_module)
        assert (b"top(C).forward" in exported_module)
        assert (b"top(C).A0(A).forward" in exported_module)
        assert (b"top(C).B0(B).forward" in exported_module)

        assert (b"mobile_debug.pkl" in optimized_exported_module)
        assert (b"module_debug_info" in optimized_exported_module)
        assert (b"top(C).forward" in optimized_exported_module)
        assert (b"top(C).A0(A).forward" in optimized_exported_module)
        assert (b"top(C).B0(B).forward" in optimized_exported_module)
Example #26
    def test_quantized_conv_no_asan_failures(self):
        # There were ASAN failures when fold_conv_bn was run on
        # already quantized conv modules. Verifying that this does
        # not happen again.

        if 'qnnpack' not in torch.backends.quantized.supported_engines:
            return

        class Child(nn.Module):
            def __init__(self):
                super(Child, self).__init__()
                self.conv2 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv2(x)
                return x

        class Parent(nn.Module):
            def __init__(self):
                super(Parent, self).__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.child = Child()
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.child(x)
                x = self.dequant(x)
                return x

        with override_quantized_engine('qnnpack'):
            model = Parent()
            model.qconfig = torch.ao.quantization.get_default_qconfig(
                'qnnpack')
            torch.ao.quantization.prepare(model, inplace=True)
            model(torch.randn(4, 1, 4, 4))
            torch.ao.quantization.convert(model, inplace=True)
            model = torch.jit.script(model)
            # this line should not have ASAN failures
            model_optim = optimize_for_mobile(model)
Example #27
def run(
        weights='./yolov5s.pt',  # weights path
        img_size=(640, 640),  # image (height, width)
        batch_size=1,  # batch size
        device='cpu',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
        include=('torchscript', 'onnx', 'coreml'),  # include formats
        half=False,  # FP16 half-precision export
        inplace=False,  # set YOLOv5 Detect() inplace=True
        train=False,  # model.train() mode
        optimize=False,  # TorchScript: optimize for mobile
        dynamic=False,  # ONNX: dynamic axes
        simplify=False,  # ONNX: simplify model
        opset_version=12,  # ONNX: opset version
):
    t = time.time()
    include = [x.lower() for x in include]
    img_size *= 2 if len(img_size) == 1 else 1  # expand

    # Load PyTorch model
    device = select_device(device)
    assert not (
        device.type == 'cpu' and half
    ), '--half only compatible with GPU export, i.e. use --device 0'
    model = attempt_load(weights, map_location=device)  # load FP32 model
    labels = model.names

    # Input
    gs = int(max(model.stride))  # grid size (max stride)
    img_size = [check_img_size(x, gs)
                for x in img_size]  # verify img_size are gs-multiples
    img = torch.zeros(batch_size, 3, *img_size).to(
        device)  # image size(1,3,320,192) iDetection

    # Update model
    if half:
        img, model = img.half(), model.half()  # to FP16
    model.train() if train else model.eval(
    )  # training mode = no Detect() layer grid construction
    for k, m in model.named_modules():
        m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
        if isinstance(m, Conv):  # assign export-friendly activations
            if isinstance(m.act, nn.Hardswish):
                m.act = Hardswish()
            elif isinstance(m.act, nn.SiLU):
                m.act = SiLU()
        elif isinstance(m, Detect):
            m.inplace = inplace
            m.onnx_dynamic = dynamic
            # m.forward = m.forward_export  # assign forward (optional)

    for _ in range(2):
        y = model(img)  # dry runs
    print(
        f"\n{colorstr('PyTorch:')} starting from {weights} ({file_size(weights):.1f} MB)"
    )

    # TorchScript export -----------------------------------------------------------------------------------------------
    if 'torchscript' in include or 'coreml' in include:
        prefix = colorstr('TorchScript:')
        try:
            print(
                f'\n{prefix} starting export with torch {torch.__version__}...'
            )
            f = weights.replace('.pt', '.torchscript.pt')  # filename
            ts = torch.jit.trace(model, img, strict=False)
            (optimize_for_mobile(ts) if optimize else ts).save(f)
            print(
                f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)'
            )
        except Exception as e:
            print(f'{prefix} export failure: {e}')

    # ONNX export ------------------------------------------------------------------------------------------------------
    if 'onnx' in include:
        prefix = colorstr('ONNX:')
        try:
            import onnx

            print(f'{prefix} starting export with onnx {onnx.__version__}...')
            f = weights.replace('.pt', '.onnx')  # filename
            torch.onnx.export(
                model,
                img,
                f,
                verbose=False,
                opset_version=opset_version,
                training=torch.onnx.TrainingMode.TRAINING
                if train else torch.onnx.TrainingMode.EVAL,
                do_constant_folding=not train,
                input_names=['images'],
                output_names=['output'],
                dynamic_axes={
                    'images': {
                        0: 'batch',
                        2: 'height',
                        3: 'width'
                    },  # shape(1,3,640,640)
                    'output': {
                        0: 'batch',
                        1: 'anchors'
                    }  # shape(1,25200,85)
                } if dynamic else None)

            # Checks
            model_onnx = onnx.load(f)  # load onnx model
            onnx.checker.check_model(model_onnx)  # check onnx model
            # print(onnx.helper.printable_graph(model_onnx.graph))  # print

            # Simplify
            if simplify:
                try:
                    check_requirements(['onnx-simplifier'])
                    import onnxsim

                    print(
                        f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...'
                    )
                    model_onnx, check = onnxsim.simplify(
                        model_onnx,
                        dynamic_input_shape=dynamic,
                        input_shapes={'images': list(img.shape)}
                        if dynamic else None)
                    assert check, 'assert check failed'
                    onnx.save(model_onnx, f)
                except Exception as e:
                    print(f'{prefix} simplifier failure: {e}')
            print(
                f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)'
            )
        except Exception as e:
            print(f'{prefix} export failure: {e}')

    # CoreML export ----------------------------------------------------------------------------------------------------
    if 'coreml' in include:
        prefix = colorstr('CoreML:')
        try:
            import coremltools as ct

            print(
                f'{prefix} starting export with coremltools {ct.__version__}...'
            )
            assert train, 'CoreML exports should be placed in model.train() mode with `python export.py --train`'
            model = ct.convert(ts,
                               inputs=[
                                   ct.ImageType('image',
                                                shape=img.shape,
                                                scale=1 / 255.0,
                                                bias=[0, 0, 0])
                               ])
            f = weights.replace('.pt', '.mlmodel')  # filename
            model.save(f)
            print(
                f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)'
            )
        except Exception as e:
            print(f'{prefix} export failure: {e}')

    # Finish
    print(
        f'\nExport complete ({time.time() - t:.2f}s). Visualize with https://github.com/lutzroeder/netron.'
    )
Example #28
    def test_optimize_for_mobile(self):
        batch_size = 2
        input_channels_per_group = 6
        height = 16
        width = 16
        output_channels_per_group = 6
        groups = 4
        kernel_h = kernel_w = 3
        stride_h = stride_w = 1
        pad_h = pad_w = 1
        dilation = 1
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
        conv_bias_shape = (output_channels)

        input_data = torch.rand((batch_size, input_channels, height, width))
        conv_weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
        conv_bias = torch.rand((output_channels))
        result = F.conv2d(input_data, conv_weight, conv_bias, strides, paddings, dilations, groups)
        weight_output_dim = 24
        linear_input_shape = result.shape[1]
        linear_weight_shape = (weight_output_dim, linear_input_shape)

        class MyTestModule(torch.nn.Module):
            def __init__(self):
                super(MyTestModule, self).__init__()
                self.conv_weight = torch.nn.Parameter(torch.rand(conv_weight_shape))
                self.conv_bias = torch.nn.Parameter(torch.rand((conv_bias_shape)))
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape))
                self.linear_bias = torch.nn.Parameter(torch.rand((weight_output_dim)))
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.conv2d(x, self.conv_weight, self.conv_bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = F.relu(o)
                x = o.permute([0, 2, 3, 1])
                o = F.linear(x, self.linear_weight, self.linear_bias)
                o = o + x
                return F.relu(o)

            @torch.jit.export
            def foo(self, x):
                o = F.conv2d(x, self.conv_weight, self.conv_bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = F.relu(o)
                x = o.permute([0, 2, 3, 1])
                o = F.linear(x, self.linear_weight, self.linear_bias)
                o = o + x
                return F.relu(o)


        class BNTestModule(torch.nn.Module):
            def __init__(self):
                super(BNTestModule, self).__init__()
                self.conv = torch.nn.Conv2d(1, 20, 5, 1)
                self.bn = torch.nn.BatchNorm2d(num_features=20)
                self.bn.eps = 0.0023

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        data_shape = (batch_size, input_channels, height, width)
        input_data = torch.normal(1, 20, size=data_shape)

        scripted_model = torch.jit.script(MyTestModule())
        scripted_model.eval()
        initial_result = scripted_model(input_data)
        initial_foo_result = scripted_model.foo(input_data)

        optimized_scripted_model = optimize_for_mobile(scripted_model, preserved_methods=['foo'])
        optimized_result = optimized_scripted_model(input_data)
        optimized_foo_result = optimized_scripted_model.foo(input_data)

        FileCheck().check_not("Tensor = aten::conv2d") \
                   .check_not("Tensor = prim::CallFunction") \
                   .check_not("prepacked::conv2d_clamp_prepack") \
                   .check_count("prepacked::conv2d_clamp_run", 1, exactly=True) \
                   .check_not("prepacked::linear_clamp_prepack") \
                   .check_count("prepacked::linear_clamp_run", 1, exactly=True) \
                   .check_not("aten::add(") \
                   .check_not("aten::relu(") \
                   .check_count("aten::_add_relu(", 1, exactly=True) \
                   .run(optimized_scripted_model.graph)
        torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)

        FileCheck().check_not("Tensor = aten::conv2d") \
                   .check_not("Tensor = prim::CallFunction") \
                   .check_not("prepacked::conv2d_clamp_prepack") \
                   .check_count("prepacked::conv2d_clamp_run", 1, exactly=True) \
                   .check_not("prepacked::linear_clamp_prepack") \
                   .check_count("prepacked::linear_clamp_run", 1, exactly=True) \
                   .check_not("aten::add(") \
                   .check_not("aten::relu(") \
                   .check_count("aten::_add_relu(", 1, exactly=True) \
                   .run(optimized_scripted_model.foo.graph)
        torch.testing.assert_allclose(initial_foo_result, optimized_foo_result, rtol=1e-2, atol=1e-3)


        optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
        optimized_scripted_model_no_prepack = optimize_for_mobile(scripted_model, optimization_blocklist_no_prepack)
        optimized_result_no_prepack = optimized_scripted_model_no_prepack(input_data)

        FileCheck().check_count("Tensor = aten::conv2d", 1, exactly=True) \
                   .check_not("prepacked::linear_clamp_run") \
                   .check_not("prepacked::conv2d_clamp_run") \
                   .run(optimized_scripted_model_no_prepack.graph)
        torch.testing.assert_allclose(initial_result, optimized_result_no_prepack, rtol=1e-2, atol=1e-3)


        bn_test_module = BNTestModule()
        bn_scripted_module = torch.jit.script(bn_test_module)
        bn_scripted_module.eval()

        self.assertEqual(len(torch.jit.export_opnames(bn_scripted_module)), 14)
        FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
                   .run(str(get_forward(bn_scripted_module._c).graph))

        optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
        bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_prepack)
        self.assertEqual(len(torch.jit.export_opnames(bn_fold_scripted_module)), 1)
        bn_input = torch.rand(1, 1, 6, 6)
        torch.testing.assert_allclose(bn_scripted_module(bn_input), bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)

        optimization_blocklist_no_fold_bn = {MobileOptimizerType.CONV_BN_FUSION}
        no_bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_fold_bn)
        FileCheck().check_count("aten::batch_norm", 1, exactly=True) \
                   .run(str(get_forward_graph(no_bn_fold_scripted_module._c)))
        bn_input = torch.rand(1, 1, 6, 6)
        torch.testing.assert_allclose(bn_scripted_module(bn_input), no_bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)

        class MyMobileOptimizedTagTest(torch.nn.Module):
            def __init__(self):
                super(MyMobileOptimizedTagTest, self).__init__()
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape))
                self.linear_bias = torch.nn.Parameter(torch.rand((weight_output_dim)))

            def forward(self, x):
                o = F.linear(x, self.linear_weight, self.linear_bias)
                return F.relu(o)

        mobile_optimized_tag_module = MyMobileOptimizedTagTest()
        m = torch.jit.script(mobile_optimized_tag_module)
        m.eval()
        opt_m = optimize_for_mobile(m)
        tag = getattr(opt_m, "mobile_optimized", None)
        self.assertTrue(tag)

        class MyPreserveMethodsTest(torch.nn.Module):
            def __init__(self):
                super(MyPreserveMethodsTest, self).__init__()
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape))
                self.linear_bias = torch.nn.Parameter(torch.rand((weight_output_dim)))

            def forward(self, x):
                o = F.linear(x, self.linear_weight, self.linear_bias)
                return F.relu(o)

            @torch.jit.export
            def preserveThis(self):
                pass

        preserve_method_module = MyPreserveMethodsTest()
        m = torch.jit.script(preserve_method_module)
        m.eval()
        opt_m = optimize_for_mobile(m)
        no_preserveThis = getattr(opt_m, "preserveThis", None)
        self.assertEqual(no_preserveThis, None)
        opt_m = optimize_for_mobile(m, preserved_methods=["preserveThis"])
        preserveThis = getattr(opt_m, "preserveThis", None)
        self.assertNotEqual(preserveThis, None)

        class OptimizeNoForwardTest(torch.nn.Module):
            def __init__(self):
                super(OptimizeNoForwardTest, self).__init__()
                self.l = nn.Linear(10, 100)
                self.l2 = nn.Linear(100, 1)
                self.d = nn.Dropout(p=0.2)

            @torch.jit.export
            def foo(self, x):
                x = self.d(F.relu(self.l(x)))
                x = self.l2(x)
                x = x + torch.ones(1, 100)
                return F.relu(x)
        input_data = torch.ones(1, 10)
        m = torch.jit.script(OptimizeNoForwardTest())
        m.eval()
        initial_result = m.foo(input_data)

        optimized_scripted_model = optimize_for_mobile(m, preserved_methods=['foo'])
        optimized_result = optimized_scripted_model.foo(input_data)

        FileCheck().check_not("dropout.__") \
            .check_count("aten::_add_relu(", 1, exactly=True) \
            .run(optimized_scripted_model.foo.graph)
        torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)

        class BNTestNoForwardModule(torch.nn.Module):
            def __init__(self):
                super(BNTestNoForwardModule, self).__init__()
                self.conv = torch.nn.Conv2d(1, 20, 5, 1)
                self.bn = torch.nn.BatchNorm2d(num_features=20)
                self.bn.eps = 0.0023

            @torch.jit.export
            def foo(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        bn_test_no_forward_module = BNTestNoForwardModule()
        bn_no_forward_scripted_module = torch.jit.script(bn_test_no_forward_module)
        bn_no_forward_scripted_module.eval()

        self.assertEqual(len(torch.jit.export_opnames(bn_no_forward_scripted_module)), 14)
        FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
                   .run(bn_no_forward_scripted_module.foo.graph)

        bn_fold_no_foward_scripted_module = optimize_for_mobile(bn_no_forward_scripted_module, preserved_methods=['foo'])
        self.assertEqual(len(torch.jit.export_opnames(bn_fold_no_foward_scripted_module)), 1)
        bn_input = torch.rand(1, 1, 6, 6)
        torch.testing.assert_allclose(
            bn_no_forward_scripted_module.foo(bn_input),
            bn_fold_no_foward_scripted_module.foo(bn_input),
            rtol=1e-2,
            atol=1e-3)
Example #29
import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile

model = torchvision.models.mobilenet_v2(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
optimized_scripted_module = optimize_for_mobile(traced_script_module)
torch.jit.save(optimized_scripted_module, '../models/model.pt')
exported_optimized_scripted_module = optimized_scripted_module._save_for_lite_interpreter(
    "../models/model_lite.ptl")
Example #30
def trace_and_save_torchscript(
    model: nn.Module,
    inputs: Tuple[torch.Tensor],
    output_path: str,
    torchscript_filename: str = "model.jit",
    mobile_optimization: Optional[MobileOptimizationConfig] = None,
    _extra_files: Optional[Dict[str, bytes]] = None,
):
    logger.info("Tracing and saving TorchScript to {} ...".format(output_path))
    PathManager.mkdirs(output_path)
    if _extra_files is None:
        _extra_files = {}

    with torch.no_grad():
        script_model = torch.jit.trace(model, inputs)

    with make_temp_directory("trace_and_save_torchscript") as tmp_dir:

        @contextlib.contextmanager
        def _synced_local_file(rel_path):
            remote_file = os.path.join(output_path, rel_path)
            local_file = os.path.join(tmp_dir, rel_path)
            yield local_file
            PathManager.copy_from_local(local_file,
                                        remote_file,
                                        overwrite=True)

        with _synced_local_file(torchscript_filename) as model_file:
            torch.jit.save(script_model, model_file, _extra_files=_extra_files)

        with _synced_local_file("data.pth") as data_file:
            torch.save(inputs, data_file)

        if mobile_optimization is not None:
            logger.info("Applying optimize_for_mobile ...")
            liteopt_model = optimize_for_mobile(
                script_model,
                optimization_blocklist=mobile_optimization.
                optimization_blocklist,
                preserved_methods=mobile_optimization.preserved_methods,
                backend=mobile_optimization.backend,
            )
            torchscript_filename = mobile_optimization.torchscript_filename
            with _synced_local_file(torchscript_filename) as lite_path:
                liteopt_model._save_for_lite_interpreter(
                    lite_path, _extra_files=_extra_files)
            # liteopt_model(*inputs)  # sanity check
            op_names = torch.jit.export_opnames(liteopt_model)
            logger.info("Operator names from lite interpreter:\n{}".format(
                "\n".join(op_names)))

            logger.info("Applying augment_model_with_bundled_inputs ...")
            # make all tensors zero-like to save storage
            iters = recursive_iterate(inputs)
            for x in iters:
                if isinstance(x, torch.Tensor):
                    iters.send(torch.zeros_like(x).contiguous())
            inputs = iters.value
            augment_model_with_bundled_inputs(liteopt_model, [inputs])
            liteopt_model(
                *liteopt_model.get_all_bundled_inputs()[0])  # sanity check
            name, ext = os.path.splitext(torchscript_filename)
            with _synced_local_file(name + "_bundled" + ext) as lite_path:
                liteopt_model._save_for_lite_interpreter(lite_path)

        return torchscript_filename