Example #1
    def test_all_backport_functions(self):
        # Backport from the latest bytecode version to the minimum supported version.
        # Load and run the backported model, and check its version.
        class TestModule(torch.nn.Module):
            def __init__(self, v):
                super().__init__()
                self.x = v

            def forward(self, y: int):
                increment = torch.ones([2, 4], dtype=torch.float64)
                return self.x + y + increment

        module_input = 1
        expected_mobile_module_result = 3 * torch.ones([2, 4],
                                                       dtype=torch.float64)

        # The temporary input and output model files will be written to the temporary folder
        with tempfile.TemporaryDirectory() as tmpdirname:
            tmp_input_model_path = Path(tmpdirname, "tmp_script_module.ptl")
            script_module = torch.jit.script(TestModule(1))
            optimized_scripted_module = optimize_for_mobile(script_module)
            optimized_scripted_module._save_for_lite_interpreter(
                str(tmp_input_model_path))

            current_from_version = _get_model_bytecode_version(
                tmp_input_model_path)
            current_to_version = current_from_version - 1
            tmp_output_model_path = Path(tmpdirname,
                                         "tmp_script_module_backport.ptl")

            while current_to_version >= MINIMUM_TO_VERSION:
                # Backport the model to `current_to_version`, writing it to the tmp file "tmp_script_module_backport.ptl"
                backport_success = _backport_for_mobile(
                    tmp_input_model_path, tmp_output_model_path,
                    current_to_version)
                assert backport_success

                backport_version = _get_model_bytecode_version(
                    tmp_output_model_path)
                assert backport_version == current_to_version

                # Load the backported model and run its forward method
                mobile_module = _load_for_lite_interpreter(
                    str(tmp_output_model_path))
                mobile_module_result = mobile_module(module_input)
                torch.testing.assert_close(mobile_module_result,
                                           expected_mobile_module_result)
                current_to_version -= 1

            # Check backport failure case
            backport_success = _backport_for_mobile(tmp_input_model_path,
                                                    tmp_output_model_path,
                                                    MINIMUM_TO_VERSION - 1)
            assert not backport_success
            # Clean the folder before the context manager closes it; otherwise the test can hit a "git not clean" error
            shutil.rmtree(tmpdirname)
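
The snippet above is a method from PyTorch's mobile bytecode test suite, so it relies on module-level imports and the MINIMUM_TO_VERSION constant that are not shown here. Below is a minimal, self-contained sketch of the same round trip; it assumes the private helpers are importable from torch.jit.mobile and optimize_for_mobile from torch.utils.mobile_optimizer (the usual locations in the PyTorch tests), and the module and paths are purely illustrative.

import tempfile
from pathlib import Path

import torch
from torch.jit.mobile import (
    _backport_for_mobile,
    _get_model_bytecode_version,
    _load_for_lite_interpreter,
)
from torch.utils.mobile_optimizer import optimize_for_mobile


class AddModule(torch.nn.Module):
    def forward(self, y: int):
        # Same shape/dtype as the test module above
        return torch.ones([2, 4], dtype=torch.float64) + y


with tempfile.TemporaryDirectory() as tmpdirname:
    input_path = Path(tmpdirname, "model.ptl")
    output_path = Path(tmpdirname, "model_backport.ptl")

    # Script, optimize, and save the model for the lite interpreter
    scripted = torch.jit.script(AddModule())
    optimize_for_mobile(scripted)._save_for_lite_interpreter(str(input_path))

    # Read the bytecode version of the saved model and backport it by one version
    from_version = _get_model_bytecode_version(input_path)
    assert _backport_for_mobile(input_path, output_path, from_version - 1)
    assert _get_model_bytecode_version(output_path) == from_version - 1

    # The backported model should still load and run on the lite interpreter
    result = _load_for_lite_interpreter(str(output_path))(1)
    print(result)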
Example #2
    def test_backport_bytecode_from_file_to_file(self):
        maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL.keys())
        script_module_v5_path = pytorch_test_dir / "cpp" / "jit" / SCRIPT_MODULE_BYTECODE_PKL[
            maximum_checked_in_model_version]["model_name"]

        if maximum_checked_in_model_version > MINIMUM_TO_VERSION:
            with tempfile.TemporaryDirectory() as tmpdirname:
                tmp_backport_model_path = Path(tmpdirname, "tmp_script_module_v5_backported_to_v4.ptl")
                # backport from file
                success = _backport_for_mobile(
                    script_module_v5_path,
                    tmp_backport_model_path,
                    maximum_checked_in_model_version - 1)
                assert success

                buf = io.StringIO()
                torch.utils.show_pickle.main(
                    ["", tmpdirname + "/" + tmp_backport_model_path.name + "@*/bytecode.pkl"],
                    output_stream=buf)
                output = buf.getvalue()

                expected_result = SCRIPT_MODULE_V4_BYTECODE_PKL
                actual_result_clean = "".join(output.split())
                expected_result_clean = "".join(expected_result.split())
                is_match = fnmatch.fnmatch(actual_result_clean, expected_result_clean)
                assert is_match

                # Load model v4 and run forward method
                mobile_module = _load_for_lite_interpreter(str(tmp_backport_model_path))
                module_input = 1
                mobile_module_result = mobile_module(module_input)
                expected_mobile_module_result = 3 * torch.ones([2, 4], dtype=torch.float64)
                torch.testing.assert_close(mobile_module_result, expected_mobile_module_result)
                shutil.rmtree(tmpdirname)
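
Example #2 inspects the bytecode table of the backported archive with torch.utils.show_pickle, using the "path@pattern" syntax to select a member inside the .ptl zip file. A minimal sketch of just that inspection step, with an assumed (illustrative) model path:

import io

import torch.utils.show_pickle

# Illustrative path to a previously saved lite-interpreter model; the "@" suffix
# selects archive members matching a glob pattern, here the bytecode table.
model_path = "/tmp/tmp_script_module_backport.ptl"  # assumed to exist

buf = io.StringIO()
torch.utils.show_pickle.main(
    ["", model_path + "@*/bytecode.pkl"],
    output_stream=buf)
print(buf.getvalue())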
Example #3
    def test_bytecode_values_for_all_backport_functions(self):
        # Find the maximum version of the checked-in models, backport down to the minimum supported
        # version, and compare the bytecode.pkl content at each step.
        # It can't be merged into the test `test_all_backport_functions`, because optimization is dynamic and
        # the content might change when the optimize function changes. This test focuses
        # on bytecode.pkl content validation. The content validation is not a byte-to-byte check, but
        # glob-style wildcard matching (fnmatch); wildcards can be used to skip specific content comparisons.
        maximum_checked_in_model_version = max(
            SCRIPT_MODULE_BYTECODE_PKL.keys())
        current_from_version = maximum_checked_in_model_version

        with tempfile.TemporaryDirectory() as tmpdirname:
            while current_from_version > MINIMUM_TO_VERSION:
                # Get the path of the checked-in model for the current from_version
                model_name = SCRIPT_MODULE_BYTECODE_PKL[current_from_version][
                    "model_name"]
                input_model_path = pytorch_test_dir / "cpp" / "jit" / model_name

                # A temporary model file will be exported to this path and run through the
                # bytecode.pkl content check.
                tmp_output_model_path_backport = Path(
                    tmpdirname, "tmp_script_module_backport.ptl")

                current_to_version = current_from_version - 1
                backport_success = _backport_for_mobile(
                    input_model_path, tmp_output_model_path_backport,
                    current_to_version)
                assert backport_success

                expect_bytecode_pkl = SCRIPT_MODULE_BYTECODE_PKL[
                    current_to_version]["bytecode_pkl"]

                buf = io.StringIO()
                torch.utils.show_pickle.main(
                    ["", tmpdirname + "/" +
                     tmp_output_model_path_backport.name + "@*/bytecode.pkl"],
                    output_stream=buf)
                output = buf.getvalue()

                actual_result_clean = "".join(output.split())
                expected_result_clean = "".join(expect_bytecode_pkl.split())
                is_match = fnmatch.fnmatch(actual_result_clean,
                                           expected_result_clean)
                assert is_match

                current_from_version -= 1
            shutil.rmtree(tmpdirname)
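
Both bytecode checks above normalize whitespace and then rely on fnmatch rather than exact string equality, so the checked-in expected text can use "*" wildcards for fields that vary between builds. A small sketch of that comparison pattern, using made-up stand-in strings rather than the real checked-in pkl content:

import fnmatch

# Illustrative stand-ins, not the real checked-in bytecode.pkl text: the
# expected text uses "*" wildcards to skip fields that may vary.
expected_text = "('instructions', *) ('operators', *)"
actual_text = "('instructions', [('LOADC', 0, 0)]) ('operators', [])"

# Strip all whitespace on both sides, then do a glob-style wildcard match,
# exactly as the tests above do before asserting.
actual_clean = "".join(actual_text.split())
expected_clean = "".join(expected_text.split())
assert fnmatch.fnmatch(actual_clean, expected_clean)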