def test_export_opnames_interface(self):
    """Export opnames through an interface-typed submodule with interface export on.

    ``M.sub`` is declared with the ``OneTwoModule`` interface type, so the
    operators reported by ``torch.jit.export_opnames`` should follow whichever
    concrete module is currently bound to that attribute.
    """
    @torch.jit.interface
    class OneTwoModule(nn.Module):
        def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            pass

        def two(self, x: torch.Tensor) -> torch.Tensor:
            pass

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            pass

    class FooMod(nn.Module):
        def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y

        def two(self, x: torch.Tensor) -> torch.Tensor:
            return 2 * x

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.one(self.two(x), x)

    class BarMod(nn.Module):
        def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x * y

        def two(self, x: torch.Tensor) -> torch.Tensor:
            return 2 / x

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.two(self.one(x, x))

    make_global(OneTwoModule)

    class M(nn.Module):
        sub: OneTwoModule

        def __init__(self):
            super().__init__()
            self.sub = BarMod()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.sub.forward(x)

    def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):
        return mod_list[0].forward(x) + mod_list[1].forward(x)

    torch._C._enable_mobile_interface_call_export()
    scripted = torch.jit.script(M())

    # Ops expected from BarMod, which M binds to ``sub`` in __init__.
    bar_ops = {'aten::mul.Scalar', 'aten::mul.Tensor', 'aten::reciprocal'}
    self.assertTrue(bar_ops.issubset(set(torch.jit.export_opnames(scripted))))

    # After rebinding the interface attribute, FooMod's ops should be reported.
    scripted.sub = torch.jit.script(FooMod())
    foo_ops = {'aten::add.Tensor', 'aten::mul.Scalar'}
    self.assertTrue(foo_ops.issubset(set(torch.jit.export_opnames(scripted))))
def test_export_opnames_interface(self):
    """Export opnames through an interface-typed submodule (lite interpreter).

    The lite interpreter does not yet support interface calls, so
    ``torch.jit.export_opnames`` is currently expected to report no operators,
    whichever concrete module is bound to the interface attribute ``sub``.
    """
    @torch.jit.interface
    class OneTwoModule(nn.Module):
        def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            pass

        def two(self, x: torch.Tensor) -> torch.Tensor:
            pass

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            pass

    class FooMod(nn.Module):
        def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y

        def two(self, x: torch.Tensor) -> torch.Tensor:
            return 2 * x

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.one(self.two(x), x)

    class BarMod(nn.Module):
        def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x * y

        def two(self, x: torch.Tensor) -> torch.Tensor:
            return 2 / x

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.two(self.one(x, x))

    make_global(OneTwoModule)

    class M(nn.Module):
        sub: OneTwoModule

        def __init__(self):
            super(M, self).__init__()
            self.sub = BarMod()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.sub.forward(x)

    def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):
        return mod_list[0].forward(x) + mod_list[1].forward(x)

    scripted_M_mod = torch.jit.script(M())
    # TODO: the lite interpreter does not support interface calls yet. Once it
    # does, replace these emptiness checks with subset checks against the
    # expected ops, e.g. {'aten::mul.Scalar', 'aten::mul.Tensor',
    # 'aten::reciprocal'} while BarMod is bound to ``sub``.
    # assertEqual (rather than assertTrue(len(...) == 0)) reports the actual
    # count on failure.
    self.assertEqual(len(torch.jit.export_opnames(scripted_M_mod)), 0)

    scripted_M_mod.sub = torch.jit.script(FooMod())
    self.assertEqual(len(torch.jit.export_opnames(scripted_M_mod)), 0)
def test_class_with_multiple_methods(self):
    """PDT-script a class whose methods each get their own profiling inputs."""
    class PDTModelWithManyMethods:
        def test_list_to_dict(self, a):
            new_dictionary: Dict[float, bool] = {}
            for element in a:
                new_dictionary[element] = True
            return new_dictionary

        def test_substring(self, a, b):
            return b in a

    make_global(PDTModelWithManyMethods)
    eager_model = PDTModelWithManyMethods()

    float_list_examples: List[Tuple[Any, ...]] = [([1.2, 2.3],)]
    substring_examples: List[Tuple[Any, ...]] = [("abc", "b")]
    # Profiling inputs are keyed by the bound method they should feed.
    method_examples = {
        eager_model.test_list_to_dict: float_list_examples,
        eager_model.test_substring: substring_examples,
    }
    scripted_cls = torch.jit.script(PDTModelWithManyMethods,
                                    example_inputs=method_examples)
    script_model = scripted_cls()

    self.assertEqual(script_model.test_list_to_dict([1.1, 2.2, 3.3]),
                     eager_model.test_list_to_dict([1.1, 2.2, 3.3]))
    for haystack, needle in (("helloworld", "world"), ("helloworld", "def")):
        self.assertEqual(script_model.test_substring(haystack, needle),
                         eager_model.test_substring(haystack, needle))
def test_any(self):
    """PDT-script functions profiled with arguments of several Python types."""
    def test_multiple_types(a):
        assert not isinstance(a, bool)
        return a

    def test_multiple_type_refinement(a):
        if isinstance(a, bool):
            return 1
        elif isinstance(a, int):
            return 1 + a
        elif isinstance(a, float):
            return 1 + int(a)
        else:
            return -1

    make_global(test_multiple_types, test_multiple_type_refinement)

    scripted_fn = torch.jit._script_pdt(
        test_multiple_types,
        example_inputs=[(1,), ("abc",), (8.9,), ([3, 4, 5],)])
    for probe in (10, "def", 7.89999, [10, 11, 14]):
        self.assertEqual(scripted_fn(probe), test_multiple_types(probe))

    scripted_fn = torch.jit._script_pdt(
        test_multiple_type_refinement,
        example_inputs=[(1,), ("abc",), (8.9,), ([3, 4, 5],), (True,),
                        ({"a": True},)])
    refinement_probes = (10, "def", 7.89999, [10, 11, 14], False,
                         {"abc": True, "def": False})
    for probe in refinement_probes:
        self.assertEqual(scripted_fn(probe), test_multiple_type_refinement(probe))
def test_class_methods(self):
    """PDT-script a single-method class and compare it with eager execution."""
    class PDTModel:
        def test_sum(self, a):
            return sum(a)

    make_global(PDTModel)
    eager_model = PDTModel()
    examples: List[Tuple[Any, ...]] = [([10, 20],)]
    scripted_cls = torch.jit._script_pdt(
        PDTModel, example_inputs={eager_model.test_sum: examples})
    scripted_model = scripted_cls()
    self.assertEqual(scripted_model.test_sum([10, 20, 30]),
                     eager_model.test_sum([10, 20, 30]))
def test_heterogenous_value_type_enum_error(self):
    """Scripting must reject an Enum whose members have different value types."""
    class Color(Enum):
        RED = 1
        GREEN = "green"

    make_global(Color)

    def enum_comp(x: Color, y: Color) -> bool:
        return x == y

    expected_msg = "Could not unify type list"
    with self.assertRaisesRegex(RuntimeError, expected_msg):
        torch.jit.script(enum_comp)
def test_fx_tracing_with_typing(self):
    """PDT-script a module whose forward returns a typed NamedTuple."""
    class FXModelOutput(NamedTuple):
        result: List[int]

    class FXModel(torch.nn.Module):
        def forward(self, a) -> FXModelOutput:
            result = FXModelOutput(result=a)
            return result

    make_global(FXModel, FXModelOutput)
    eager_model = FXModel()
    profiling_inputs = {eager_model: [([10, 20],)]}
    scripted = torch.jit._script_pdt(eager_model, example_inputs=profiling_inputs)
    self.assertEqual(scripted([20]), eager_model([20]))
def test_multiple_class_with_same_method(self):
    """Two PDT-scripted classes sharing a method name are profiled independently."""
    class PDTModelOne:
        def test_find(self, a, b):
            return b in a.keys()

    class PDTModelTwo:
        def test_find(self, a, b):
            return b in a

    make_global(PDTModelOne, PDTModelTwo)
    eager_one = PDTModelOne()
    eager_two = PDTModelTwo()

    # Same method name, different profiled argument types per class.
    dict_examples: List[Tuple[Any, ...]] = [({1.2: True, 2.3: False}, 1.2)]
    list_examples: List[Tuple[Any, ...]] = [(["abc", "b"], "c")]

    scripted_cls_one = torch.jit.script(
        PDTModelOne, example_inputs={eager_one.test_find: dict_examples})
    scripted_cls_two = torch.jit.script(
        PDTModelTwo, example_inputs={eager_two.test_find: list_examples})
    script_one = scripted_cls_one()
    script_two = scripted_cls_two()

    self.assertEqual(
        script_one.test_find({1.1: True, 2.2: True, 3.3: False}, 4.4),
        eager_one.test_find({1.1: True, 2.2: True, 3.3: False}, 4.4))
    self.assertEqual(
        script_two.test_find(["hello", "world"], "world"),
        eager_two.test_find(["hello", "world"], "world"))
def test_non_existent_enum_value(self):
    """Referencing an undefined Enum member fails to script and highlights the access."""
    class Color(Enum):
        RED = 1
        GREEN = 2

    make_global(Color)

    def enum_const(x: Color) -> bool:
        if x == Color.PURPLE:
            return True
        else:
            return False

    expected_msg = "has no attribute 'PURPLE'"
    expected_highlight = "Color.PURPLE"
    with self.assertRaisesRegexWithHighlight(RuntimeError, expected_msg,
                                             expected_highlight):
        torch.jit.script(enum_const)
def test_heterogenous_value_type_enum_error(self):
    """Scripting must reject an Enum whose members have different value types."""
    class Color(Enum):
        RED = 1
        GREEN = "green"

    make_global(Color)

    def enum_comp(x: Color, y: Color) -> bool:
        return x == y

    # TODO: rewrite code so that the highlight is not empty.
    expected_msg = "Could not unify type list"
    with self.assertRaisesRegexWithHighlight(RuntimeError, expected_msg, ""):
        torch.jit.script(enum_comp)
def test_enum_comp(self):
    """Enum equality scripts to ``aten::eq`` and compares members correctly."""
    class Color(Enum):
        RED = 1
        GREEN = 2

    make_global(Color)

    @torch.jit.script
    def enum_comp(x: Color, y: Color) -> bool:
        return x == y

    graph_text = str(enum_comp.graph)
    FileCheck().check("aten::eq").run(graph_text)

    cases = ((Color.RED, Color.RED, True), (Color.RED, Color.GREEN, False))
    for lhs, rhs, expected in cases:
        self.assertEqual(enum_comp(lhs, rhs), expected)
def test_cast_overloads(self):
    """Scripted int()/float()/bool()/str() dispatch to a class's cast dunders.

    Also checks that using an object whose ``__bool__`` returns a non-bool as
    an ``if`` condition is rejected at script time.
    """
    @torch.jit.script
    class Foo(object):
        def __init__(self, val: float) -> None:
            self.val = val

        def __int__(self):
            return int(self.val)

        def __float__(self):
            return self.val

        def __bool__(self):
            return bool(self.val)

        def __str__(self):
            return str(self.val)

    make_global(Foo)  # see [local resolution in python]

    def test(foo: Foo) -> Tuple[int, float, bool]:
        if foo:
            pass
        return int(foo), float(foo), bool(foo)

    fn = torch.jit.script(test)
    # Compare against eager execution on the *same* Foo input, so the eager
    # reference also goes through Foo's dunders rather than float's builtins.
    for val in (0.5, 0.0):
        self.assertEqual(fn(Foo(val)), test(Foo(val)))

    # str has slightly different formatting
    self.assertIn("0.5", str(Foo(0.5)))
    self.assertIn("0.", str(Foo(0.0)))

    @torch.jit.script
    class BadBool(object):
        def __init__(self):
            pass

        def __bool__(self):
            return (1, 2)

    with self.assertRaisesRegexWithHighlight(
            RuntimeError, "expected a bool expression for condition", ""):
        # Named distinctly so it does not shadow the ``test`` function above.
        @torch.jit.script
        def use_bad_bool():
            if BadBool():
                print(1)
                pass
def test_nn_parameter_as_arg(self):
    """PDT-script a module that passes an nn.Parameter into another method."""
    class TestNNParameter(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.inp = torch.nn.Parameter(torch.ones(2, 3))

        def add_nn_parameter_with_int(self, x, y):
            return torch.add(x, y)

        def forward(self, y):
            return self.add_nn_parameter_with_int(self.inp, y)

    make_global(TestNNParameter)
    eager_model = TestNNParameter()
    profiling_inputs = {eager_model: [(10,)]}
    scripted = torch.jit._script_pdt(eager_model, example_inputs=profiling_inputs)
    self.assertEqual(scripted(20), eager_model(20))
def test_enum_return(self):
    """A scripted function can return Enum members."""
    class Color(Enum):
        RED = 1
        GREEN = 2

    make_global(Color)

    @torch.jit.script
    def return_enum(cond: bool):
        if cond:
            return Color.RED
        else:
            return Color.GREEN

    for cond, expected in ((True, Color.RED), (False, Color.GREEN)):
        self.assertEqual(return_enum(cond), expected)
def test_pdt_dict(self):
    """PDT infers dict types for str-keyed and int-keyed example inputs."""
    def test_dict(a):
        return a['foo']

    def test_dict_int_list(a):
        return a[1]

    make_global(test_dict, test_dict_int_list)

    str_bool_example = {'foo': True, 'bar': False}
    scripted_fn = torch.jit._script_pdt(test_dict,
                                        example_inputs=[(str_bool_example,)])
    self.assertEqual(scripted_fn({'foo': False, 'bar': True}),
                     test_dict({'foo': False, 'bar': True}))

    # Keys are ints, values are lists of bools.
    int_list_example = {0: [True, False], 1: [False, True]}
    scripted_fn = torch.jit._script_pdt(test_dict_int_list,
                                        example_inputs=[(int_list_example,)])
    self.assertEqual(scripted_fn({0: [False, False], 1: [True, True]}),
                     test_dict_int_list({0: [False, False], 1: [True, True]}))
def test_pdt_list_and_tuple(self):
    """PDT infers element types of list and tuple inputs (float, bool, int)."""
    def test_list_and_tuple(a):
        return sum(a)

    make_global(test_list_and_tuple)

    # (profiling example, probe) pairs: lists first, then tuples, each in
    # float / bool / int order. Each pair is scripted afresh.
    cases = (
        ([4.9, 8.9], [11.9, 7.6]),
        ([True, False, True], [True, True, True]),
        ([3, 4, 5], [1, 2, 3]),
        ((4.9, 8.9), (11.9, 7.6)),
        ((True, False, True), (True, True, True)),
        ((3, 4, 5), (1, 2, 3)),
    )
    for example, probe in cases:
        scripted_fn = torch.jit.script(test_list_and_tuple,
                                       example_inputs=[(example,)])
        self.assertEqual(scripted_fn(probe), test_list_and_tuple(probe))
def test_string_enum_as_module_attribute(self):
    """A string-valued Enum held as a module attribute exposes name and value."""
    class Color(Enum):
        RED = "red"
        GREEN = "green"

    class TestModule(torch.nn.Module):
        def __init__(self, e: Color):
            super().__init__()
            self.e = e

        def forward(self):
            return (self.e.name, self.e.value)

    make_global(Color)
    scripted = torch.jit.script(TestModule(Color.RED))
    expected = (Color.RED.name, Color.RED.value)
    self.assertEqual(scripted(), expected)
def test_class_type_as_param(self):
    """A scripted function can accept a user-defined class instance."""
    class FooTest(object):  # noqa: B903
        def __init__(self, x):
            self.attr = x

    make_global(FooTest)  # see [local resolution in python]

    @torch.jit.script
    def fn(foo: FooTest) -> torch.Tensor:
        return foo.attr

    @torch.jit.script
    def fn2(x):
        foo = FooTest(x)
        return fn(foo)

    ones = torch.ones(1)
    self.assertEqual(fn2(ones), ones)
def test_enum_ivalue_type(self):
    """isinstance against an Enum scripts to prim::isinstance on the Enum type."""
    class Color(Enum):
        RED = 1
        GREEN = 2

    make_global(Color)

    @torch.jit.script
    def is_color_enum(x: Any):
        return isinstance(x, Color)

    # NOTE(review): the qualified name presumably assumes this test lives in
    # the jit/test_enum.py module — confirm if the test is moved.
    checker = (FileCheck()
               .check("prim::isinstance[types=[Enum<__torch__.jit.test_enum.Color>]]")
               .check_next("return"))
    checker.run(str(is_color_enum.graph))

    for value, expected in ((Color.RED, True), (Color.GREEN, True), (1, False)):
        self.assertEqual(is_color_enum(value), expected)
def test_enum_as_const(self):
    """An Enum member referenced in a scripted function becomes a graph constant."""
    class Color(Enum):
        RED = 1
        GREEN = 2

    make_global(Color)

    @torch.jit.script
    def enum_const(x: Color) -> bool:
        return x == Color.RED

    # NOTE(review): the qualified name presumably assumes this test lives in
    # the jit/test_enum.py module — confirm if the test is moved.
    checker = (FileCheck()
               .check("prim::Constant[value=__torch__.jit.test_enum.Color.RED]")
               .check_next("aten::eq")
               .check_next("return"))
    checker.run(str(enum_const.graph))

    for member, expected in ((Color.RED, True), (Color.GREEN, False)):
        self.assertEqual(enum_const(member), expected)
def test_enum_value(self):
    """Reading ``.value`` on an Enum scripts to prim::EnumValue."""
    class Color(Enum):
        RED = 1
        GREEN = 2

    make_global(Color)

    @torch.jit.script
    def enum_value(x: Color) -> int:
        return x.value

    checker = (FileCheck()
               .check("Color")
               .check_next("prim::EnumValue")
               .check_next("return"))
    checker.run(str(enum_value.graph))

    for member in (Color.RED, Color.GREEN):
        self.assertEqual(enum_value(member), member.value)
def test_py_class_to_ivalue_missing_attribute(self):
    """Passing an object lacking the class's attributes to a scripted fn raises."""
    class Foo(object):
        i: int
        f: float

        def __init__(self, i: int, f: float):
            self.i = i
            self.f = f

    make_global(Foo)  # see [local resolution in python]

    @torch.jit.script
    def test_fn(x: Foo) -> float:
        return x.i + x.f

    # A real Foo is accepted.
    test_fn(Foo(3, 4.0))

    # A Tensor is rejected because it lacks Foo's attribute ``i``.
    with self.assertRaisesRegexWithHighlight(RuntimeError, 'missing attribute i', ""):
        test_fn(torch.rand(3, 4))
def test_union_with_enum(self):
    """A Union[str, Enum] parameter accepts both member types and rejects others."""
    class Color(Enum):
        RED = 1
        GREEN = 2

    make_global(Color)

    def fn(x: Union[str, Color]) -> str:
        return "foo"

    for arg in (Color.RED, "red"):
        self.checkScript(fn, (arg,))

    scripted = torch.jit.script(fn)
    expected_error = ("Expected a member of"
                      r" Union\[__torch__.jit.test_union."
                      r"Color, str\] but instead found "
                      "type int")
    with self.assertRaisesRegex(RuntimeError, expected_error):
        scripted(1)
def test_nn_module(self):
    """PDT-script an nn.Module whose forward refines an int / float / bool input."""
    class TestPDTModel(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x) -> Any:
            if isinstance(x, int):
                return x + 1
            elif isinstance(x, float):
                return x - 1
            else:
                return x

    make_global(TestPDTModel)
    pdt_model = TestPDTModel()
    inp: List[Tuple[Any, ...]] = [(20,), (2.7,), (False,)]
    scripted_pdt_model = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model: inp})
    self.assertEqual(scripted_pdt_model(50), pdt_model(50))
    self.assertEqual(scripted_pdt_model(1.8), pdt_model(1.8))
    # For a bool input only truthiness is checked. In eager Python ``True`` is
    # an ``int``, so the eager model returns 2. It is not confirmed here that
    # TorchScript's isinstance refinement agrees. Passing the eager result as
    # assertTrue's second argument would only set the failure *message*, not
    # compare the two, so it is not passed.
    # TODO(review): switch to assertEqual if scripted and eager results match.
    self.assertTrue(scripted_pdt_model(True))
def test_class_as_profiled_types(self):
    """PDT profiles a user-defined class instance passed with scalar inputs."""
    class UserDefinedClass:
        def fn(self, b) -> Any:
            assert b is not None
            if isinstance(b, int):
                return b if b > 0 else -1
            elif isinstance(b, float):
                return b if b > 0.0 else -1.0
            return 0

    def test_model(a, m):
        assert not isinstance(a, bool)
        return m.fn(a)

    make_global(UserDefinedClass, test_model)
    helper = UserDefinedClass()
    scripted_fn = torch.jit._script_pdt(
        test_model, example_inputs=[(10, helper), (10.9, helper)])
    for scalar in (100, 1.9):
        self.assertEqual(scripted_fn(scalar, helper), test_model(scalar, helper))
def test_class_with_args_as_profiled_types(self):
    """PDT profiles a user-defined class whose constructor takes arguments."""
    class ClassWithArgs:
        def __init__(self, a: bool):
            self.a = a

        def fn(self, b):
            if self.a:
                return b
            else:
                return -1

    def test_model_with_args(a, m):
        assert not isinstance(a, bool)
        return m.fn(a)

    make_global(ClassWithArgs, test_model_with_args)
    profiling_instance = ClassWithArgs(False)
    scripted_fn = torch.jit._script_pdt(
        test_model_with_args,
        example_inputs=[(10, profiling_instance), (10.9, profiling_instance)])
    self.assertEqual(scripted_fn(100, ClassWithArgs(True)),
                     test_model_with_args(100, ClassWithArgs(True)))
def test_nested_function_in_forward(self):
    """PDT profiles the argument types of a helper method called by forward."""
    class NestedFunctionInForward(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return self.fun(x) + 10

        def fun(self, x):
            if isinstance(x, bool):
                return 0
            elif isinstance(x, int):
                return x + 1
            return 0

    make_global(NestedFunctionInForward)
    eager_model = NestedFunctionInForward()
    examples: List[Tuple[Any, ...]] = [(-1,), (False,)]
    scripted = torch.jit._script_pdt(eager_model, example_inputs={eager_model: examples})
    for probe in (30, True):
        self.assertEqual(scripted(probe), eager_model(probe))
def test_enum_value_types(self):
    """int/float/str-valued Enums script; a Tensor-valued Enum is rejected."""
    class IntEnum(Enum):
        FOO = 1
        BAR = 2

    class FloatEnum(Enum):
        FOO = 1.2
        BAR = 2.3

    class StringEnum(Enum):
        FOO = "foo as in foo bar"
        BAR = "bar as in foo bar"

    make_global(IntEnum, FloatEnum, StringEnum)

    @torch.jit.script
    def supported_enum_types(a: IntEnum, b: FloatEnum, c: StringEnum):
        return (a.name, b.name, c.name)

    # Each Enum type should appear in the graph, in this order.
    checker = FileCheck()
    for enum_name in ("IntEnum", "FloatEnum", "StringEnum"):
        checker = checker.check(enum_name)
    checker.run(str(supported_enum_types.graph))

    class TensorEnum(Enum):
        FOO = torch.tensor(0)
        BAR = torch.tensor(1)

    make_global(TensorEnum)

    def unsupported_enum_types(a: TensorEnum):
        return a.name

    # TODO: rewrite code so that the highlight is not empty.
    with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Cannot create Enum with value type 'Tensor'", ""):
        torch.jit.script(unsupported_enum_types)
def test_enum_comp_diff_classes(self):
    """Comparing with a same-valued member of another Enum class yields False."""
    class Foo(Enum):
        ITEM1 = 1
        ITEM2 = 2

    class Bar(Enum):
        ITEM1 = 1
        ITEM2 = 2

    make_global(Foo, Bar)

    @torch.jit.script
    def enum_comp(x: Foo) -> bool:
        return x == Bar.ITEM1

    checker = (FileCheck()
               .check("prim::Constant")
               .check_same("Bar.ITEM1")
               .check("aten::eq"))
    checker.run(str(enum_comp.graph))

    self.assertEqual(enum_comp(Foo.ITEM1), False)
def test_staticmethod(self):
    """ Test static methods on class types.

    Covers a scripted class whose staticmethods call the constructor and
    another staticmethod. A free function invokes one of them, and that
    function is run through ``checkScript``.
    """
    @torch.jit.script
    class ClassWithStaticMethod:
        def __init__(self, a: int, b: int):
            self.a: int = a
            self.b: int = b

        def get_a(self):
            return self.a

        def get_b(self):
            return self.b

        # NOTE(review): presumably defined so checkScript can compare the
        # scripted and eager results field-by-field — confirm.
        def __eq__(self, other: 'ClassWithStaticMethod'):
            return self.a == other.a and self.b == other.b

        # staticmethod that calls constructor.
        @staticmethod
        def create(
            args: List['ClassWithStaticMethod']
        ) -> 'ClassWithStaticMethod':
            return ClassWithStaticMethod(args[0].a, args[0].b)

        # staticmethod that calls another staticmethod.
        @staticmethod
        def create_from(a: int, b: int) -> 'ClassWithStaticMethod':
            # ``a`` is rebound from the int argument to a class instance here.
            a = ClassWithStaticMethod(a, b)
            return ClassWithStaticMethod.create([a])

    # Script function that calls staticmethod.
    def test_function(a: int, b: int) -> 'ClassWithStaticMethod':
        return ClassWithStaticMethod.create_from(a, b)

    # see [local resolution in python]-style registration so test_function
    # can resolve the locally defined class when scripted.
    make_global(ClassWithStaticMethod)
    self.checkScript(test_function, (1, 2))