Esempio n. 1
0
 def wrapper_save():
     """Script ``cls``, collect its mobile operators, and save it to disk.

     Side effects happen in this order:
     1. Append ``cls`` to ``_MODELS``.
     2. Merge the operator names of the lite-interpreter build of the model
        into ``_OPERATORS``.
     3. Append the output path to ``_FILENAMES``.
     4. Write ``./<cls.__name__>.ptl``.
     """
     _MODELS.append(cls)
     script_module = torch.jit.script(cls())

     # Round-trip through an in-memory lite-interpreter buffer so the
     # operator list can be read from the resulting mobile module.
     stream = BytesIO(script_module._save_to_buffer_for_lite_interpreter())
     stream.seek(0)
     _OPERATORS.update(_export_operator_list(_load_for_lite_interpreter(stream)))

     out_path = f"./{cls.__name__}.ptl"
     _FILENAMES.append(out_path)
     script_module._save_for_lite_interpreter(out_path)
    def test_module_export_operator_list(self):
        # A module whose forward uses a small, fixed set of ATen ops, so the
        # exported mobile operator list can be compared against an exact set.
        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.ones((20, 1, 5, 5))
                self.bias = torch.ones(20)

            def forward(self, inp):
                zeros_out = torch.zeros(2, 2)
                like_out = torch.empty_like(torch.empty(2, 2))
                conv_out = torch._convolution(
                    inp,
                    self.weight,
                    self.bias,
                    [1, 1],  # stride
                    [0, 0],  # padding
                    [1, 1],  # dilation
                    False,  # transposed
                    [0, 0],  # output_padding
                    1,  # groups
                    False,  # benchmark
                    False,  # deterministic
                    True,  # cudnn_enabled
                    True,  # allow_tf32
                )
                return (zeros_out, like_out, conv_out)

        # Serialize for the lite interpreter and reload as a mobile module.
        scripted = torch.jit.script(Foo())
        stream = io.BytesIO(scripted._save_to_buffer_for_lite_interpreter())
        stream.seek(0)
        mobile = _load_for_lite_interpreter(stream)

        self.assertEqual(
            _export_operator_list(mobile),
            {
                "aten::zeros",
                "aten::empty_like",
                "aten::empty.memory_format",
                "aten::_convolution",
            },
        )
Esempio n. 3
0
def get_operator_list(script_module: torch.jit.ScriptModule) -> Set[str]:
    """Return the operator names a scripted module needs on mobile.

    The module is serialized for the lite interpreter into an in-memory
    buffer, reloaded as a mobile module, and its operator list exported.

    Args:
        script_module: A TorchScript module (e.g. the result of
            ``torch.jit.script``).

    Returns:
        The set of operator names reported by ``_export_operator_list``.
    """
    buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
    # Rewind explicitly so the loader reads from the start of the buffer.
    buffer.seek(0)
    mobile_module = _load_for_lite_interpreter(buffer)
    return _export_operator_list(mobile_module)