def test_codegen(self):
    """Codegen for the MNIST JSON graph must reproduce the checked-in reference script.

    Loads the serialized graph, regenerates PyTorch source from it, and
    compares against the committed reference file (whitespace-trimmed).
    """
    with open(self.enclosing_dir / 'debug_mnist_pytorch.py') as fp:
        expected = fp.read().strip()
    with open(self.enclosing_dir / 'mnist_pytorch.json') as fp:
        graph = Model._load(json.load(fp))
    self.assertEqual(model_to_pytorch_script(graph).strip(), expected)
def test_mnist_example_pytorch(self):
    """Regenerated PyTorch source for the MNIST example must match the reference file.

    NOTE(review): fixture paths are relative to the current working
    directory — confirm the test runner chdirs to the fixture folder.
    """
    with open('debug_mnist_pytorch.py') as fp:
        expected = fp.read().strip()
    with open('mnist_pytorch.json') as fp:
        graph = Model._load(json.load(fp))
    self.assertEqual(model_to_pytorch_script(graph).strip(), expected)
def run_test(self, model, input, check_value=True):
    """Convert `model` to graph IR, regenerate PyTorch code, and compare outputs.

    Args:
        model: the source ``nn.Module`` to round-trip.
        input: tuple of positional inputs fed to both models.
        check_value: when True, compare the two models' outputs.

    Returns:
        The regenerated (converted) model instance.
    """
    model_ir = self._convert_model(model, input)
    model_code = model_to_pytorch_script(model_ir)
    print(model_code)

    # Undo the nn-injection patching before executing the generated code.
    from .inject_nn import remove_inject_pytorch_nn
    remove_inject_pytorch_nn()

    exec_vars = {}
    # The generated source is produced by our own codegen, not untrusted input.
    exec(model_code + '\n\nconverted_model = _model()', exec_vars)
    converted_model = exec_vars['converted_model']

    # Transplant the original weights so both models compute with the same parameters.
    converted_state_dict = self._match_state_dict(
        list(model.state_dict().values()),
        dict(converted_model.state_dict()))
    converted_model.load_state_dict(converted_state_dict)

    with torch.no_grad():
        expected_output = model.eval()(*input)
        converted_output = converted_model.eval()(*input)

    if check_value:
        try:
            self.assertEqual(len(converted_output), len(expected_output))
            for a, b in zip(converted_output, expected_output):
                # Fix: the original called `torch.eq(a, b)` and discarded the
                # result, so element mismatches were never reported.
                self.assertTrue(torch.eq(a, b).all().item())
        # Fix: narrowed from a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit. TypeError covers outputs that are
        # not sized sequences; AssertionError keeps the direct-equality
        # fallback for mismatched structures.
        except (TypeError, AssertionError):
            self.assertEqual(converted_output, expected_output)
    return converted_model
def checkExportImport(self, model, input, check_value=True):
    """Round-trip `model` through IR codegen and check the outputs agree.

    Args:
        model: source ``nn.Module``.
        input: tuple of positional inputs fed to both models.
        check_value: when True, compare outputs element-wise.

    Returns:
        The regenerated model instance.
    """
    model_ir = self._convert_model(model, input)
    model_code = model_to_pytorch_script(model_ir)
    print(model_code)

    namespace = {}
    exec(model_code + '\n\nconverted_model = _model()', namespace)
    converted_model = namespace['converted_model']

    # Copy the original weights into the regenerated module.
    with original_state_dict_hooks(converted_model):
        converted_model.load_state_dict(model.state_dict())

    with torch.no_grad():
        expected_output = model.eval()(*input)
        converted_output = converted_model.eval()(*input)

    if check_value:
        self.assertEqual(len(converted_output), len(expected_output))
        for actual, expected in zip(converted_output, expected_output):
            # Bool tensors are checked first: subtraction is not defined for them.
            if hasattr(actual, 'dtype') and actual.dtype == torch.bool:
                self.assertEqual((actual ^ expected), False)
            else:
                diff = actual - expected
                if isinstance(diff, int):
                    self.assertEqual(diff, 0)
                else:
                    self.assertLess(diff.abs().max().item(), 1E-4)
    return converted_model
def checkExportImport(self, model, input):
    """Regenerate `model` from its graph IR and require near-identical outputs.

    Returns the regenerated model instance.
    """
    model_ir = self._convert_model(model, input)
    code = model_to_pytorch_script(model_ir)

    namespace = {}
    exec(code + '\n\nconverted_model = _model()', namespace)
    regenerated = namespace['converted_model']

    # Load the source model's weights into the regenerated module.
    with original_state_dict_hooks(regenerated):
        regenerated.load_state_dict(dict(model.state_dict()))

    with torch.no_grad():
        expected_output = model.eval()(*input)
        converted_output = regenerated.eval()(*input)

    self.assertEqual(len(converted_output), len(expected_output))
    for actual, expected in zip(converted_output, expected_output):
        self.assertLess((actual - expected).abs().max().item(), 1E-4)
    return regenerated
def checkExportImport(self, model, input):
    """TorchScript `model`, convert to graph IR, regenerate, and compare outputs.

    Returns the regenerated model instance.
    """
    scripted = torch.jit.script(model)
    model_ir = convert_to_graph(scripted, model)
    code = model_to_pytorch_script(model_ir)

    namespace = {}
    exec(code + '\n\nconverted_model = _model()', namespace)
    regenerated = namespace['converted_model']

    # Map the original parameters onto the regenerated module's state dict.
    matched = self._match_state_dict(
        list(model.state_dict().values()),
        dict(regenerated.state_dict()))
    regenerated.load_state_dict(matched)

    with torch.no_grad():
        expected_output = model.eval()(*input)
        converted_output = regenerated.eval()(*input)

    self.assertEqual(len(converted_output), len(expected_output))
    for actual, expected in zip(converted_output, expected_output):
        self.assertLess((actual - expected).abs().max().item(), 1E-4)
    return regenerated
def run_test(self, model, input, check_value=True):
    """Convert `model` to graph IR, regenerate PyTorch code, and compare outputs.

    Args:
        model: the source ``nn.Module`` to round-trip.
        input: tuple of positional inputs fed to both models.
        check_value: when True, compare the two models' outputs.

    Returns:
        The regenerated (converted) model instance.
    """
    model_ir = self._convert_model(model, input)
    model_code = model_to_pytorch_script(model_ir)
    print(model_code)

    exec_vars = {}
    # The generated source is produced by our own codegen, not untrusted input.
    exec(model_code + '\n\nconverted_model = _model()', exec_vars)
    converted_model = exec_vars['converted_model']

    # Load the original weights so both models compute with the same parameters.
    with original_state_dict_hooks(converted_model):
        converted_model.load_state_dict(model.state_dict())

    with torch.no_grad():
        expected_output = model.eval()(*input)
        converted_output = converted_model.eval()(*input)

    if check_value:
        try:
            self.assertEqual(len(converted_output), len(expected_output))
            for a, b in zip(converted_output, expected_output):
                # Fix: the original called `torch.eq(a, b)` and discarded the
                # result, so element mismatches were never reported.
                self.assertTrue(torch.eq(a, b).all().item())
        # Fix: narrowed from a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit. TypeError covers outputs that are
        # not sized sequences; AssertionError keeps the direct-equality
        # fallback for mismatched structures.
        except (TypeError, AssertionError):
            self.assertEqual(converted_output, expected_output)
    return converted_model
def checkExportImport(self, model, input, check_value=True):
    """TorchScript `model`, convert to graph IR, regenerate, and compare outputs.

    Args:
        model: the source ``nn.Module`` to round-trip.
        input: tuple of positional inputs fed to both models.
        check_value: when True, compare the two models' outputs.

    Returns:
        The regenerated (converted) model instance.
    """
    script_module = torch.jit.script(model)
    model_ir = convert_to_graph(script_module, model)
    model_code = model_to_pytorch_script(model_ir)

    exec_vars = {}
    # The generated source is produced by our own codegen, not untrusted input.
    exec(model_code + '\n\nconverted_model = _model()', exec_vars)
    converted_model = exec_vars['converted_model']

    # Transplant the original weights so both models compute with the same parameters.
    converted_state_dict = self._match_state_dict(
        list(model.state_dict().values()),
        dict(converted_model.state_dict()))
    converted_model.load_state_dict(converted_state_dict)

    with torch.no_grad():
        expected_output = model.eval()(*input)
        converted_output = converted_model.eval()(*input)

    if check_value:
        try:
            self.assertEqual(len(converted_output), len(expected_output))
            for a, b in zip(converted_output, expected_output):
                # Fix: the original called `torch.eq(a, b)` and discarded the
                # result, so element mismatches were never reported.
                self.assertTrue(torch.eq(a, b).all().item())
        # Fix: narrowed from a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit. TypeError covers outputs that are
        # not sized sequences; AssertionError keeps the direct-equality
        # fallback for mismatched structures.
        except (TypeError, AssertionError):
            self.assertEqual(converted_output, expected_output)
    return converted_model
def _get_converted_pytorch_model(self, model_ir):
    """Generate PyTorch source from `model_ir`, execute it, and return a `_model()` instance."""
    source = model_to_pytorch_script(model_ir)
    namespace = {}
    exec(source + '\n\nconverted_model = _model()', namespace)
    return namespace['converted_model']