Beispiel #1
0
 def test_codegen(self):
     """Check that code generated from ``mnist_pytorch.json`` matches ``debug_mnist_pytorch.py``.

     Both files are looked up under ``self.enclosing_dir``; leading and
     trailing whitespace is ignored in the comparison.
     """
     ir_path = self.enclosing_dir / 'mnist_pytorch.json'
     with open(ir_path) as ir_file:
         generated = model_to_pytorch_script(Model._load(json.load(ir_file)))
     reference_path = self.enclosing_dir / 'debug_mnist_pytorch.py'
     with open(reference_path) as ref_file:
         expected = ref_file.read()
     self.assertEqual(generated.strip(), expected.strip())
Beispiel #2
0
 def test_mnist_example_pytorch(self):
     """The script generated for the MNIST IR must equal the checked-in reference.

     Both paths are relative, so they resolve against the current working
     directory; surrounding whitespace is ignored.
     """
     with open('mnist_pytorch.json') as ir_file:
         ir_json = json.load(ir_file)
         generated = model_to_pytorch_script(Model._load(ir_json))
     with open('debug_mnist_pytorch.py') as ref_file:
         expected = ref_file.read()
     self.assertEqual(generated.strip(), expected.strip())
Beispiel #3
0
    def run_test(self, model, input, check_value=True):
        """Round-trip ``model`` through the PyTorch code generator and compare outputs.

        The model is converted to IR (``self._convert_model``), rendered as
        PyTorch source, executed to build ``_model()``, loaded with the
        original weights and run next to the original model in eval mode
        under ``torch.no_grad()``.

        Args:
            model: the module under test.
            input: sequence of positional arguments passed to both models.
            check_value: when true, assert that both models produce matching
                outputs (sequences element-wise, bool data exactly, numeric
                data within 1E-4).

        Returns:
            The converted model instance.
        """
        model_ir = self._convert_model(model, input)
        model_code = model_to_pytorch_script(model_ir)
        print(model_code)

        # NOTE(review): presumably undoes torch.nn patching used during
        # conversion so the generated code runs on stock modules — confirm.
        from .inject_nn import remove_inject_pytorch_nn
        remove_inject_pytorch_nn()

        exec_vars = {}
        exec(model_code + '\n\nconverted_model = _model()', exec_vars)
        converted_model = exec_vars['converted_model']
        # Map the original parameter values onto the converted model's keys.
        converted_state_dict = self._match_state_dict(
            list(model.state_dict().values()),
            dict(converted_model.state_dict()))
        converted_model.load_state_dict(converted_state_dict)
        with torch.no_grad():
            expected_output = model.eval()(*input)
            converted_output = converted_model.eval()(*input)

        def assert_close(actual, expected):
            # Recursively compare a converted output with the expected one:
            # sequences element-wise, bool data exactly, numbers within 1E-4.
            if isinstance(actual, (list, tuple)) and isinstance(expected, (list, tuple)):
                self.assertEqual(len(actual), len(expected))
                for sub_actual, sub_expected in zip(actual, expected):
                    assert_close(sub_actual, sub_expected)
            elif isinstance(actual, torch.Tensor) or isinstance(expected, torch.Tensor):
                actual, expected = torch.as_tensor(actual), torch.as_tensor(expected)
                self.assertEqual(actual.shape, expected.shape)
                if actual.dtype == torch.bool or expected.dtype == torch.bool:
                    # Bool tensors do not support ``-``; require exact equality.
                    self.assertFalse(torch.any(actual != expected).item())
                else:
                    diff = (actual - expected).abs()
                    if diff.numel() > 0:  # ``max()`` fails on empty tensors
                        self.assertLess(diff.max().item(), 1E-4)
            elif isinstance(actual, (int, float)) and isinstance(expected, (int, float)):
                self.assertLess(abs(actual - expected), 1E-4)
            else:
                self.assertEqual(actual, expected)

        if check_value:
            # ``torch.eq`` only returns a tensor; the values must be asserted.
            assert_close(converted_output, expected_output)
        return converted_model
Beispiel #4
0
    def checkExportImport(self, model, input, check_value=True):
        """Round-trip ``model`` through the PyTorch code generator and compare outputs.

        The model is converted to IR (``self._convert_model``), rendered as
        PyTorch source, executed to build ``_model()``, loaded with the
        original state dict and run next to the original model in eval mode
        under ``torch.no_grad()``.

        Args:
            model: the module under test.
            input: sequence of positional arguments passed to both models.
            check_value: when true, assert that both models produce matching
                outputs (sequences element-wise, bool data exactly, numeric
                data within 1E-4).

        Returns:
            The converted model instance.
        """
        model_ir = self._convert_model(model, input)
        model_code = model_to_pytorch_script(model_ir)
        print(model_code)

        exec_vars = {}
        exec(model_code + '\n\nconverted_model = _model()', exec_vars)
        converted_model = exec_vars['converted_model']

        # NOTE(review): presumably lets the converted model accept the
        # original model's state-dict key names — confirm.
        with original_state_dict_hooks(converted_model):
            converted_model.load_state_dict(model.state_dict())

        with torch.no_grad():
            expected_output = model.eval()(*input)
            converted_output = converted_model.eval()(*input)

        def assert_close(actual, expected):
            # Recursively compare a converted output with the expected one:
            # sequences element-wise, bool data exactly, numbers within 1E-4.
            if isinstance(actual, (list, tuple)) and isinstance(expected, (list, tuple)):
                self.assertEqual(len(actual), len(expected))
                for sub_actual, sub_expected in zip(actual, expected):
                    assert_close(sub_actual, sub_expected)
            elif isinstance(actual, torch.Tensor) or isinstance(expected, torch.Tensor):
                actual, expected = torch.as_tensor(actual), torch.as_tensor(expected)
                self.assertEqual(actual.shape, expected.shape)
                if actual.dtype == torch.bool or expected.dtype == torch.bool:
                    # ``assertEqual(a ^ b, False)`` is ambiguous for
                    # multi-element tensors; reduce to a single bool instead.
                    self.assertFalse(torch.any(actual != expected).item())
                else:
                    diff = (actual - expected).abs()
                    if diff.numel() > 0:  # ``max()`` fails on empty tensors
                        self.assertLess(diff.max().item(), 1E-4)
            elif isinstance(actual, (int, float)) and isinstance(expected, (int, float)):
                # Plain Python scalars have no ``.abs()`` method.
                self.assertLess(abs(actual - expected), 1E-4)
            else:
                self.assertEqual(actual, expected)

        if check_value:
            assert_close(converted_output, expected_output)
        return converted_model
Beispiel #5
0
    def checkExportImport(self, model, input):
        """Round-trip ``model`` through the PyTorch code generator and compare outputs.

        The model is converted to IR (``self._convert_model``), rendered as
        PyTorch source, executed to build ``_model()``, loaded with a copy of
        the original state dict and run next to the original model in eval
        mode under ``torch.no_grad()``. Outputs must match: sequences
        element-wise, bool data exactly, numeric data within 1E-4.

        Args:
            model: the module under test.
            input: sequence of positional arguments passed to both models.

        Returns:
            The converted model instance.
        """
        model_ir = self._convert_model(model, input)
        model_code = model_to_pytorch_script(model_ir)

        exec_vars = {}
        exec(model_code + '\n\nconverted_model = _model()', exec_vars)
        converted_model = exec_vars['converted_model']
        # NOTE(review): presumably lets the converted model accept the
        # original model's state-dict key names — confirm.
        with original_state_dict_hooks(converted_model):
            converted_model.load_state_dict(dict(model.state_dict()))
        with torch.no_grad():
            expected_output = model.eval()(*input)
            converted_output = converted_model.eval()(*input)

        def assert_close(actual, expected):
            # Recursively compare a converted output with the expected one:
            # sequences element-wise, bool data exactly, numbers within 1E-4.
            if isinstance(actual, (list, tuple)) and isinstance(expected, (list, tuple)):
                self.assertEqual(len(actual), len(expected))
                for sub_actual, sub_expected in zip(actual, expected):
                    assert_close(sub_actual, sub_expected)
            elif isinstance(actual, torch.Tensor) or isinstance(expected, torch.Tensor):
                actual, expected = torch.as_tensor(actual), torch.as_tensor(expected)
                self.assertEqual(actual.shape, expected.shape)
                if actual.dtype == torch.bool or expected.dtype == torch.bool:
                    # Bool tensors do not support ``-``; require exact equality.
                    self.assertFalse(torch.any(actual != expected).item())
                else:
                    diff = (actual - expected).abs()
                    if diff.numel() > 0:  # ``max()`` fails on empty tensors
                        self.assertLess(diff.max().item(), 1E-4)
            elif isinstance(actual, (int, float)) and isinstance(expected, (int, float)):
                # Plain Python scalars have no ``.abs()`` method.
                self.assertLess(abs(actual - expected), 1E-4)
            else:
                self.assertEqual(actual, expected)

        assert_close(converted_output, expected_output)
        return converted_model
Beispiel #6
0
    def checkExportImport(self, model, input):
        """Script ``model``, regenerate it as PyTorch source and compare outputs.

        ``model`` is compiled with ``torch.jit.script``, converted to a graph
        IR with ``convert_to_graph``, rendered as source, executed to build
        ``_model()``, loaded with the original weights and run next to the
        original model in eval mode under ``torch.no_grad()``. Outputs must
        match: sequences element-wise, bool data exactly, numeric data within
        1E-4.

        Args:
            model: the module under test (must be scriptable).
            input: sequence of positional arguments passed to both models.

        Returns:
            The converted model instance.
        """
        script_module = torch.jit.script(model)
        model_ir = convert_to_graph(script_module, model)
        model_code = model_to_pytorch_script(model_ir)

        exec_vars = {}
        exec(model_code + '\n\nconverted_model = _model()', exec_vars)
        converted_model = exec_vars['converted_model']
        # Map the original parameter values onto the converted model's keys.
        converted_state_dict = self._match_state_dict(
            list(model.state_dict().values()),
            dict(converted_model.state_dict()))
        converted_model.load_state_dict(converted_state_dict)
        with torch.no_grad():
            expected_output = model.eval()(*input)
            converted_output = converted_model.eval()(*input)

        def assert_close(actual, expected):
            # Recursively compare a converted output with the expected one:
            # sequences element-wise, bool data exactly, numbers within 1E-4.
            if isinstance(actual, (list, tuple)) and isinstance(expected, (list, tuple)):
                self.assertEqual(len(actual), len(expected))
                for sub_actual, sub_expected in zip(actual, expected):
                    assert_close(sub_actual, sub_expected)
            elif isinstance(actual, torch.Tensor) or isinstance(expected, torch.Tensor):
                actual, expected = torch.as_tensor(actual), torch.as_tensor(expected)
                self.assertEqual(actual.shape, expected.shape)
                if actual.dtype == torch.bool or expected.dtype == torch.bool:
                    # Bool tensors do not support ``-``; require exact equality.
                    self.assertFalse(torch.any(actual != expected).item())
                else:
                    diff = (actual - expected).abs()
                    if diff.numel() > 0:  # ``max()`` fails on empty tensors
                        self.assertLess(diff.max().item(), 1E-4)
            elif isinstance(actual, (int, float)) and isinstance(expected, (int, float)):
                # Plain Python scalars have no ``.abs()`` method.
                self.assertLess(abs(actual - expected), 1E-4)
            else:
                self.assertEqual(actual, expected)

        assert_close(converted_output, expected_output)
        return converted_model
Beispiel #7
0
    def run_test(self, model, input, check_value=True):
        """Round-trip ``model`` through the PyTorch code generator and compare outputs.

        The model is converted to IR (``self._convert_model``), rendered as
        PyTorch source, executed to build ``_model()``, loaded with the
        original state dict and run next to the original model in eval mode
        under ``torch.no_grad()``.

        Args:
            model: the module under test.
            input: sequence of positional arguments passed to both models.
            check_value: when true, assert that both models produce matching
                outputs (sequences element-wise, bool data exactly, numeric
                data within 1E-4).

        Returns:
            The converted model instance.
        """
        model_ir = self._convert_model(model, input)
        model_code = model_to_pytorch_script(model_ir)
        print(model_code)

        exec_vars = {}
        exec(model_code + '\n\nconverted_model = _model()', exec_vars)
        converted_model = exec_vars['converted_model']

        # NOTE(review): presumably lets the converted model accept the
        # original model's state-dict key names — confirm.
        with original_state_dict_hooks(converted_model):
            converted_model.load_state_dict(model.state_dict())

        with torch.no_grad():
            expected_output = model.eval()(*input)
            converted_output = converted_model.eval()(*input)

        def assert_close(actual, expected):
            # Recursively compare a converted output with the expected one:
            # sequences element-wise, bool data exactly, numbers within 1E-4.
            if isinstance(actual, (list, tuple)) and isinstance(expected, (list, tuple)):
                self.assertEqual(len(actual), len(expected))
                for sub_actual, sub_expected in zip(actual, expected):
                    assert_close(sub_actual, sub_expected)
            elif isinstance(actual, torch.Tensor) or isinstance(expected, torch.Tensor):
                actual, expected = torch.as_tensor(actual), torch.as_tensor(expected)
                self.assertEqual(actual.shape, expected.shape)
                if actual.dtype == torch.bool or expected.dtype == torch.bool:
                    # Bool tensors do not support ``-``; require exact equality.
                    self.assertFalse(torch.any(actual != expected).item())
                else:
                    diff = (actual - expected).abs()
                    if diff.numel() > 0:  # ``max()`` fails on empty tensors
                        self.assertLess(diff.max().item(), 1E-4)
            elif isinstance(actual, (int, float)) and isinstance(expected, (int, float)):
                self.assertLess(abs(actual - expected), 1E-4)
            else:
                self.assertEqual(actual, expected)

        if check_value:
            # ``torch.eq`` only returns a tensor; the values must be asserted.
            assert_close(converted_output, expected_output)
        return converted_model
    def checkExportImport(self, model, input, check_value=True):
        """Script ``model``, regenerate it as PyTorch source and compare outputs.

        ``model`` is compiled with ``torch.jit.script``, converted to a graph
        IR with ``convert_to_graph``, rendered as source, executed to build
        ``_model()``, loaded with the original weights and run next to the
        original model in eval mode under ``torch.no_grad()``.

        Args:
            model: the module under test (must be scriptable).
            input: sequence of positional arguments passed to both models.
            check_value: when true, assert that both models produce matching
                outputs (sequences element-wise, bool data exactly, numeric
                data within 1E-4).

        Returns:
            The converted model instance.
        """
        script_module = torch.jit.script(model)
        model_ir = convert_to_graph(script_module, model)
        model_code = model_to_pytorch_script(model_ir)

        exec_vars = {}
        exec(model_code + '\n\nconverted_model = _model()', exec_vars)
        converted_model = exec_vars['converted_model']
        # Map the original parameter values onto the converted model's keys.
        converted_state_dict = self._match_state_dict(
            list(model.state_dict().values()),
            dict(converted_model.state_dict()))
        converted_model.load_state_dict(converted_state_dict)
        with torch.no_grad():
            expected_output = model.eval()(*input)
            converted_output = converted_model.eval()(*input)

        def assert_close(actual, expected):
            # Recursively compare a converted output with the expected one:
            # sequences element-wise, bool data exactly, numbers within 1E-4.
            if isinstance(actual, (list, tuple)) and isinstance(expected, (list, tuple)):
                self.assertEqual(len(actual), len(expected))
                for sub_actual, sub_expected in zip(actual, expected):
                    assert_close(sub_actual, sub_expected)
            elif isinstance(actual, torch.Tensor) or isinstance(expected, torch.Tensor):
                actual, expected = torch.as_tensor(actual), torch.as_tensor(expected)
                self.assertEqual(actual.shape, expected.shape)
                if actual.dtype == torch.bool or expected.dtype == torch.bool:
                    # Bool tensors do not support ``-``; require exact equality.
                    self.assertFalse(torch.any(actual != expected).item())
                else:
                    diff = (actual - expected).abs()
                    if diff.numel() > 0:  # ``max()`` fails on empty tensors
                        self.assertLess(diff.max().item(), 1E-4)
            elif isinstance(actual, (int, float)) and isinstance(expected, (int, float)):
                self.assertLess(abs(actual - expected), 1E-4)
            else:
                self.assertEqual(actual, expected)

        if check_value:
            # ``torch.eq`` only returns a tensor; the values must be asserted.
            assert_close(converted_output, expected_output)
        return converted_model
Beispiel #9
0
 def _get_converted_pytorch_model(self, model_ir):
     """Render ``model_ir`` as PyTorch source, execute it and return ``_model()``.

     The generated source must define a callable named ``_model``; it is
     executed in a fresh namespace dict and the instance it builds is returned.
     """
     source = model_to_pytorch_script(model_ir) + '\n\nconverted_model = _model()'
     namespace = {}
     exec(source, namespace)
     return namespace['converted_model']