def test_different_modules(self):
    """
    Exercise the situation where we have the same qualified name
    in two different CompilationUnits on save/load.

    Two structurally different modules, both named ``Foo``, are scripted
    and saved to separate in-memory buffers. Both archives are then loaded
    as submodules of a single parent module, which is itself scripted,
    saved and loaded again. The test passes if none of these steps raise.
    """
    class Foo(torch.nn.Module):
        def __init__(self):
            # Zero-argument super() does not look up the local name ``Foo``,
            # which is rebound to a different class later in this test.
            super().__init__()
            self.foo = torch.nn.Linear(2, 2)
            self.bar = torch.nn.Linear(2, 2)

        def forward(self, x):
            x = self.foo(x)
            x = self.bar(x)
            return x

    first_script_module = torch.jit.script(Foo())
    first_saved_module = io.BytesIO()
    torch.jit.save(first_script_module, first_saved_module)
    # Rewind so the buffer can be read back by torch.jit.load below.
    first_saved_module.seek(0)

    # NOTE(review): presumably resets the JIT's record of already-compiled
    # classes so the second ``Foo`` gets the same qualified name -- the
    # assertEqual below is what actually verifies the names match.
    clear_class_registry()

    class Foo(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.foo = torch.nn.Linear(2, 2)

        def forward(self, x):
            x = self.foo(x)
            return x

    second_script_module = torch.jit.script(Foo())
    second_saved_module = io.BytesIO()
    # Save the very module whose qualified name is asserted on below,
    # rather than scripting a second, unrelated instance.
    torch.jit.save(second_script_module, second_saved_module)
    second_saved_module.seek(0)

    clear_class_registry()

    # Precondition for the scenario: both modules share one qualified name.
    self.assertEqual(
        first_script_module._c.qualified_name,
        second_script_module._c.qualified_name,
    )

    class ContainsBoth(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Load both archives under one parent module.
            self.add_module("second", torch.jit.load(second_saved_module))
            self.add_module("first", torch.jit.load(first_saved_module))

        def forward(self, x):
            x = self.first(x)
            x = self.second(x)
            return x

    sm = torch.jit.script(ContainsBoth())
    contains_both = io.BytesIO()
    torch.jit.save(sm, contains_both)
    contains_both.seek(0)
    # Loading the combined archive must succeed; the result is not inspected.
    torch.jit.load(contains_both)
def test_different_functions(self):
    """
    Exercise the situation where we have the same qualified name
    in two different CompilationUnits on save/load.

    Two modules named ``Foo`` each call a free function named ``lol``; the
    two ``lol`` definitions differ (one returns its argument, the other
    returns a string). Each module is saved through the flatbuffer
    serializer, both are loaded into one parent module, and the parent is
    scripted, saved and loaded again. The test passes if nothing raises.
    """
    def lol(x):
        return x

    class Foo(torch.nn.Module):
        def forward(self, x):
            return lol(x)

    first_script_module = torch.jit.script(Foo())
    first_saved_module = io.BytesIO()
    torch.jit.save_jit_module_to_flatbuffer(first_script_module, first_saved_module)
    # Rewind so the buffer can be read back below.
    first_saved_module.seek(0)

    # NOTE(review): presumably resets the JIT's record of already-compiled
    # classes so the second ``Foo`` gets the same qualified name -- the
    # assertEqual below is what actually verifies the names match.
    clear_class_registry()

    def lol(x):  # noqa: F811
        return "hello"

    class Foo(torch.nn.Module):
        def forward(self, x):
            return lol(x)

    second_script_module = torch.jit.script(Foo())
    second_saved_module = io.BytesIO()
    # Save the very module whose qualified name is asserted on below,
    # rather than scripting a second, unrelated instance.
    torch.jit.save_jit_module_to_flatbuffer(second_script_module, second_saved_module)
    second_saved_module.seek(0)

    clear_class_registry()

    # Precondition for the scenario: both modules share one qualified name.
    self.assertEqual(
        first_script_module._c.qualified_name,
        second_script_module._c.qualified_name,
    )

    class ContainsBoth(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Load both archives under one parent module.
            self.add_module(
                "second", torch.jit.jit_module_from_flatbuffer(second_saved_module))
            self.add_module(
                "first", torch.jit.jit_module_from_flatbuffer(first_saved_module))

        def forward(self, x):
            x = self.first(x)
            x = self.second(x)
            return x

    sm = torch.jit.script(ContainsBoth())
    contains_both = io.BytesIO()
    torch.jit.save_jit_module_to_flatbuffer(sm, contains_both)
    contains_both.seek(0)
    # Loading the combined archive must succeed; the result is not inspected.
    torch.jit.jit_module_from_flatbuffer(contains_both)
def test_many_collisions(self):
    """
    Save and load two modules whose NamedTuple, interface, scripted class,
    free function and module class all reuse the same names with
    different definitions, then combine both in a single parent module.
    """
    def save_and_rewind(scripted):
        # Serialize into an in-memory buffer and rewind it for loading.
        buffer = io.BytesIO()
        torch.jit.save(scripted, buffer)
        buffer.seek(0)
        return buffer

    # ---- first set of definitions ----
    class MyCoolNamedTuple(NamedTuple):
        a: int

    @torch.jit.interface
    class MyInterface(object):
        def bar(self, x):
            # type: (Tensor) -> Tensor
            pass

    @torch.jit.script
    class ImplementInterface(object):
        def __init__(self):
            pass

        def bar(self, x):
            return x

    def lol(x):
        return x

    class Foo(torch.nn.Module):
        interface: MyInterface

        def __init__(self):
            super().__init__()
            self.foo = torch.nn.Linear(2, 2)
            self.bar = torch.nn.Linear(2, 2)
            self.interface = ImplementInterface()

        def forward(self, x):
            x = self.foo(x)
            x = self.bar(x)
            x = lol(x)
            x = self.interface.bar(x)

            return x, MyCoolNamedTuple(a=5)

    first_script_module = torch.jit.script(Foo())
    first_buffer = save_and_rewind(first_script_module)

    clear_class_registry()

    # ---- second set: same names, different definitions ----
    @torch.jit.interface
    class MyInterface(object):
        def not_bar(self, x):
            # type: (Tensor) -> Tensor
            pass

    @torch.jit.script  # noqa F811
    class ImplementInterface(object):  # noqa F811
        def __init__(self):
            pass

        def not_bar(self, x):
            return x

    def lol(x):  # noqa F811
        return "asdofij"

    class MyCoolNamedTuple(NamedTuple):  # noqa F811
        a: str

    class Foo(torch.nn.Module):
        interface: MyInterface

        def __init__(self):
            super().__init__()
            self.foo = torch.nn.Linear(2, 2)
            self.interface = ImplementInterface()

        def forward(self, x):
            x = self.foo(x)
            self.interface.not_bar(x)
            x = lol(x)
            return x, MyCoolNamedTuple(a="hello")

    second_script_module = torch.jit.script(Foo())
    second_buffer = save_and_rewind(second_script_module)

    clear_class_registry()

    # Both scripted modules must report the same qualified name.
    first_name = first_script_module._c.qualified_name
    second_name = second_script_module._c.qualified_name
    self.assertEqual(first_name, second_name)

    class ContainsBoth(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # "second" is loaded and registered before "first".
            second_loaded = torch.jit.load(second_buffer)
            self.add_module("second", second_loaded)
            first_loaded = torch.jit.load(first_buffer)
            self.add_module("first", first_loaded)

        def forward(self, x):
            x, named_tuple_1 = self.first(x)
            x, named_tuple_2 = self.second(x)
            return len(x + named_tuple_2.a) + named_tuple_1.a

    combined = torch.jit.script(ContainsBoth())
    combined_buffer = save_and_rewind(combined)
    # Round-trip the parent module; only successful loading is checked.
    combined = torch.jit.load(combined_buffer)
def test_different_interfaces(self):
    """
    Exercise the situation where we have the same qualified name
    in two different CompilationUnits on save/load.

    Two versions of ``MyInterface`` / ``ImplementInterface`` / ``Foo`` are
    defined under identical names but with different method names
    (``bar`` vs ``not_bar``). Each ``Foo`` is scripted and saved, both
    archives are loaded into one parent module, and the parent is
    scripted, saved and loaded again. The test passes if nothing raises.
    """
    @torch.jit.interface
    class MyInterface(object):
        def bar(self, x):
            # type: (Tensor) -> Tensor
            pass

    @torch.jit.script
    class ImplementInterface(object):
        def __init__(self):
            pass

        def bar(self, x):
            return x

    class Foo(torch.nn.Module):
        # Declares the ``interface`` attribute as being of the interface type.
        __annotations__ = {"interface": MyInterface}

        def __init__(self):
            super().__init__()
            self.interface = ImplementInterface()

        def forward(self, x):
            return self.interface.bar(x)

    first_script_module = torch.jit.script(Foo())
    first_saved_module = io.BytesIO()
    torch.jit.save(first_script_module, first_saved_module)
    # Rewind so the buffer can be read back by torch.jit.load below.
    first_saved_module.seek(0)

    # NOTE(review): presumably resets the JIT's record of already-compiled
    # classes so the redefinitions below get the same qualified names --
    # the assertEqual further down verifies this for ``Foo``.
    clear_class_registry()

    @torch.jit.interface
    class MyInterface(object):
        def not_bar(self, x):
            # type: (Tensor) -> Tensor
            pass

    @torch.jit.script  # noqa: F811
    class ImplementInterface(object):  # noqa: F811
        def __init__(self):
            pass

        def not_bar(self, x):
            return x

    class Foo(torch.nn.Module):
        __annotations__ = {"interface": MyInterface}

        def __init__(self):
            super().__init__()
            self.interface = ImplementInterface()

        def forward(self, x):
            return self.interface.not_bar(x)

    second_script_module = torch.jit.script(Foo())
    second_saved_module = io.BytesIO()
    # Save the very module whose qualified name is asserted on below,
    # rather than scripting a second, unrelated instance.
    torch.jit.save(second_script_module, second_saved_module)
    second_saved_module.seek(0)

    clear_class_registry()

    # Precondition for the scenario: both modules share one qualified name.
    self.assertEqual(
        first_script_module._c.qualified_name,
        second_script_module._c.qualified_name,
    )

    class ContainsBoth(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Load both archives under one parent module.
            self.add_module("second", torch.jit.load(second_saved_module))
            self.add_module("first", torch.jit.load(first_saved_module))

        def forward(self, x):
            x = self.first(x)
            x = self.second(x)
            return x

    sm = torch.jit.script(ContainsBoth())
    contains_both = io.BytesIO()
    torch.jit.save(sm, contains_both)
    contains_both.seek(0)
    # Loading the combined archive must succeed; the result is not inspected.
    torch.jit.load(contains_both)