def test_imported_classes(self):
    """Save and reload a module whose forward uses imported script classes.

    Two of the classes are both called ``FooSameName`` but come from
    different modules (``foo`` and ``bar``); the third, ``FooUniqueName``,
    comes from a deeply nested package. The reloaded module is expected to
    return ``3 * input``.
    """
    import jit._imported_class_test.foo
    import jit._imported_class_test.bar
    import jit._imported_class_test.very.very.nested

    class MyMod(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, a):
            from_foo = jit._imported_class_test.foo.FooSameName(a)
            from_bar = jit._imported_class_test.bar.FooSameName(a)
            from_nested = (
                jit._imported_class_test.very.very.nested.FooUniqueName(a)
            )
            return from_foo.x + from_bar.y + from_nested.y

    module = MyMod()
    stream = io.BytesIO()
    torch.jit.save(module, stream)

    # Script classes live in a global registry for now; wiping it makes the
    # load below behave as if the model were loaded into a fresh process.
    jit_utils.clear_class_registry()

    stream.seek(0)
    reloaded = torch.jit.load(stream)

    sample = torch.rand(2, 3)
    result = reloaded(sample)
    self.assertEqual(3 * sample, result)
def test_save_load_with_classes_returned(self):
    """Save and reload a module that uses a script-class method returning
    a new instance of that same class.

    ``forward`` wraps its argument in ``FooTest``, clones it, and returns the
    clone's ``x``; the reloaded module must therefore echo its input.
    """
    @torch.jit.script
    class FooTest(object):
        def __init__(self, x):
            self.x = x

        def clone(self):
            return FooTest(self.x)

    class MyMod(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, a):
            return FooTest(a).clone().x

    module = MyMod()
    stream = io.BytesIO()
    torch.jit.save(module, stream)

    # Script classes live in a global registry for now; wiping it makes the
    # load below behave as if the model were loaded into a fresh process.
    jit_utils.clear_class_registry()

    stream.seek(0)
    reloaded = torch.jit.load(stream)

    sample = torch.rand(2, 3)
    result = reloaded(sample)
    self.assertEqual(sample, result)
def test_named_tuple_serialization(self):
    """Check that a NamedTuple returned from a script module survives a
    save/load round trip.

    The module is serialized to an in-memory buffer rather than to a file on
    disk, so the test leaves no artifact behind in the working directory and
    cannot collide with another run writing the same path.
    """
    class MyCoolNamedTuple(NamedTuple):
        a: int
        b: float
        c: List[int]

    class MyMod(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self):
            return MyCoolNamedTuple(3, 3.5, [3, 4, 5])

    mm = MyMod()
    buffer = io.BytesIO()
    torch.jit.save(mm, buffer)

    # Clear the JIT class registry before loading so the load does not rely
    # on type definitions registered while the module was being scripted.
    jit_utils.clear_class_registry()

    buffer.seek(0)
    loaded = torch.jit.load(buffer)

    out = mm()
    out_loaded = loaded()

    # Compare field by field, by name, between the original and loaded output.
    for name in ['a', 'b', 'c']:
        self.assertEqual(getattr(out_loaded, name), getattr(out, name))
def test_save_load_with_classes_nested(self):
    """Save and reload a module whose script class holds other script classes.

    ``FooTest`` stores a ``FooNestedTest`` and a ``FooNestedTest2`` (which
    itself holds a ``FooNestedTest``) and sums their ``y`` attributes into
    ``x``. The reloaded module is expected to return ``2 * input``.
    """
    @torch.jit.script  # noqa: B903
    class FooNestedTest(object):  # noqa: B903
        def __init__(self, y):
            self.y = y

    @torch.jit.script
    class FooNestedTest2(object):
        def __init__(self, y):
            self.y = y
            self.nested = FooNestedTest(y)

    @torch.jit.script
    class FooTest(object):
        def __init__(self, x):
            self.class_attr = FooNestedTest(x)
            self.class_attr2 = FooNestedTest2(x)
            self.x = self.class_attr.y + self.class_attr2.y

    class MyMod(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, a):
            return FooTest(a).x

    module = MyMod()
    stream = io.BytesIO()
    torch.jit.save(module, stream)

    # Script classes live in a global registry for now; wiping it makes the
    # load below behave as if the model were loaded into a fresh process.
    jit_utils.clear_class_registry()

    stream.seek(0)
    reloaded = torch.jit.load(stream)

    sample = torch.rand(2, 3)
    result = reloaded(sample)
    self.assertEqual(2 * sample, result)