def test_all_backport_functions(self):
    """Backport a freshly exported model to every supported older bytecode version.

    Scripts and mobile-optimizes a small module, then saves it for the lite
    interpreter. For each target version from (exported version - 1) down to
    ``MINIMUM_TO_VERSION`` it does three things:

    - backports the exported model to that version;
    - checks that the output file reports that bytecode version;
    - loads and runs the backported model and checks its result.

    Finally it checks that backporting below ``MINIMUM_TO_VERSION`` reports
    failure.
    """

    class TestModule(torch.nn.Module):
        def __init__(self, v):
            super().__init__()
            self.x = v

        def forward(self, y: int):
            increment = torch.ones([2, 4], dtype=torch.float64)
            return self.x + y + increment

    module_input = 1
    # self.x (1) + y (1) + ones -> every element equals 3.
    expected_mobile_module_result = 3 * torch.ones([2, 4], dtype=torch.float64)

    # TemporaryDirectory removes the folder (and the model files inside it)
    # when the ``with`` block exits, so no manual cleanup is needed.
    with tempfile.TemporaryDirectory() as tmpdirname:
        tmp_input_model_path = Path(tmpdirname, "tmp_script_module.ptl")
        script_module = torch.jit.script(TestModule(1))
        optimized_scripted_module = optimize_for_mobile(script_module)
        optimized_scripted_module._save_for_lite_interpreter(str(tmp_input_model_path))

        current_from_version = _get_model_bytecode_version(tmp_input_model_path)
        current_to_version = current_from_version - 1

        # The same output path is reused for every target version.
        tmp_output_model_path = Path(tmpdirname, "tmp_script_module_backport.ptl")

        while current_to_version >= MINIMUM_TO_VERSION:
            # Always backport from the freshly exported (latest-version) model.
            backport_success = _backport_for_mobile(
                tmp_input_model_path, tmp_output_model_path, current_to_version)
            assert backport_success

            backport_version = _get_model_bytecode_version(tmp_output_model_path)
            assert backport_version == current_to_version

            # Load and run the *backported* model (not the original input) so
            # this check actually exercises the version just produced.
            mobile_module = _load_for_lite_interpreter(str(tmp_output_model_path))
            mobile_module_result = mobile_module(module_input)
            torch.testing.assert_close(mobile_module_result, expected_mobile_module_result)
            current_to_version -= 1

        # Backporting below the minimum supported version must report failure.
        backport_success = _backport_for_mobile(
            tmp_input_model_path, tmp_output_model_path, MINIMUM_TO_VERSION - 1)
        assert not backport_success
def test_backport_bytecode_from_file_to_file(self):
    """Backport the newest checked-in model by one version and validate the result.

    Takes the highest version key in ``SCRIPT_MODULE_BYTECODE_PKL`` and backports
    that version's checked-in model file to the previous version. It then
    compares the backported ``bytecode.pkl`` against the expected pattern
    registered for the previous version, and finally loads and runs the
    backported model.

    Does nothing when the newest checked-in version is not above
    ``MINIMUM_TO_VERSION``.
    """
    maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL.keys())
    target_version = maximum_checked_in_model_version - 1
    script_module_path = (
        pytorch_test_dir / "cpp" / "jit"
        / SCRIPT_MODULE_BYTECODE_PKL[maximum_checked_in_model_version]["model_name"]
    )
    if maximum_checked_in_model_version > MINIMUM_TO_VERSION:
        # TemporaryDirectory removes the folder on exit; no manual cleanup needed.
        with tempfile.TemporaryDirectory() as tmpdirname:
            tmp_backport_model_path = Path(
                tmpdirname,
                f"tmp_script_module_v{maximum_checked_in_model_version}"
                f"_backported_to_v{target_version}.ptl",
            )
            # Backport from file to file.
            success = _backport_for_mobile(
                script_module_path, tmp_backport_model_path, target_version)
            assert success

            buf = io.StringIO()
            torch.utils.show_pickle.main(
                ["", tmpdirname + "/" + tmp_backport_model_path.name + "@*/bytecode.pkl"],
                output_stream=buf)
            output = buf.getvalue()

            # Look up the expected pattern for the actual target version instead
            # of hard-coding the v4 pattern, so this test still holds when a newer
            # model is checked in. This is the same lookup that
            # test_bytecode_values_for_all_backport_functions relies on.
            expected_result = SCRIPT_MODULE_BYTECODE_PKL[target_version]["bytecode_pkl"]
            # Strip all whitespace, then match with fnmatch shell-style
            # wildcards so `*` in the expected text can skip content.
            actual_result_clean = "".join(output.split())
            expect_result_clean = "".join(expected_result.split())
            assert fnmatch.fnmatch(actual_result_clean, expect_result_clean)

            # Load the backported model and run its forward method.
            mobile_module = _load_for_lite_interpreter(str(tmp_backport_model_path))
            module_input = 1
            mobile_module_result = mobile_module(module_input)
            expected_mobile_module_result = 3 * torch.ones([2, 4], dtype=torch.float64)
            torch.testing.assert_close(mobile_module_result, expected_mobile_module_result)
def test_bytecode_values_for_all_backport_functions(self):
    """Backport each checked-in model by one version and validate its bytecode.pkl.

    Starting from the highest version key in ``SCRIPT_MODULE_BYTECODE_PKL`` and
    walking down while the source version is above ``MINIMUM_TO_VERSION``, the
    test backports each version's checked-in model to the previous version. It
    then compares the resulting ``bytecode.pkl`` with the expected pattern
    registered for that previous version.

    This cannot be merged into ``test_all_backport_functions``. Mobile
    optimization is dynamic, so the bytecode of a freshly optimized model may
    change whenever the optimize function changes. This test therefore works
    from checked-in model files and focuses on ``bytecode.pkl`` content.

    The comparison is not byte-for-byte. Both sides have all whitespace removed
    and are matched with ``fnmatch`` shell-style wildcards (not regular
    expressions), so ``*`` in the expected pattern skips content that should not
    be compared.
    """
    maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL.keys())
    current_from_version = maximum_checked_in_model_version
    # TemporaryDirectory removes the folder on exit; no manual cleanup needed.
    with tempfile.TemporaryDirectory() as tmpdirname:
        # Each backported model is written to this same temporary path and then
        # run through the bytecode.pkl content check.
        tmp_output_model_path_backport = Path(tmpdirname, "tmp_script_module_backport.ptl")
        while current_from_version > MINIMUM_TO_VERSION:
            # Locate the checked-in model for the current source version.
            model_name = SCRIPT_MODULE_BYTECODE_PKL[current_from_version]["model_name"]
            input_model_path = pytorch_test_dir / "cpp" / "jit" / model_name

            current_to_version = current_from_version - 1
            backport_success = _backport_for_mobile(
                input_model_path, tmp_output_model_path_backport, current_to_version)
            assert backport_success

            expect_bytecode_pkl = SCRIPT_MODULE_BYTECODE_PKL[current_to_version]["bytecode_pkl"]

            buf = io.StringIO()
            torch.utils.show_pickle.main(
                ["", tmpdirname + "/" + tmp_output_model_path_backport.name + "@*/bytecode.pkl"],
                output_stream=buf)
            output = buf.getvalue()

            actual_result_clean = "".join(output.split())
            expect_result_clean = "".join(expect_bytecode_pkl.split())
            assert fnmatch.fnmatch(actual_result_clean, expect_result_clean)

            current_from_version -= 1