示例#1
0
    def test_export_opnames_interface(self):
        # OneTwoModule is the interface type of M.sub; FooMod and BarMod are two
        # concrete implementations that are swapped in after scripting.
        @torch.jit.interface
        class OneTwoModule(nn.Module):
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                pass

            def two(self, x: torch.Tensor) -> torch.Tensor:
                pass

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                pass

        class FooMod(nn.Module):
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x + y

            def two(self, x: torch.Tensor) -> torch.Tensor:
                return 2 * x

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.one(self.two(x), x)

        class BarMod(nn.Module):
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x * y

            def two(self, x: torch.Tensor) -> torch.Tensor:
                return 2 / x

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.two(self.one(x, x))

        make_global(OneTwoModule)

        class M(nn.Module):
            sub: OneTwoModule

            def __init__(self):
                super().__init__()
                self.sub = BarMod()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.sub.forward(x)

        def use_module_interface(mod_list: List[OneTwoModule],
                                 x: torch.Tensor):
            return mod_list[0].forward(x) + mod_list[1].forward(x)

        # Interface calls are only exported once this global switch is on.
        torch._C._enable_mobile_interface_call_export()
        scripted_m = torch.jit.script(M())

        # With BarMod as the submodule, its ops must appear in the export list.
        bar_expected = {'aten::mul.Scalar', 'aten::mul.Tensor', 'aten::reciprocal'}
        bar_exported = set(torch.jit.export_opnames(scripted_m))
        self.assertTrue(bar_expected.issubset(bar_exported))

        # After swapping in FooMod, the export list must reflect FooMod's ops.
        scripted_m.sub = torch.jit.script(FooMod())
        foo_expected = {'aten::add.Tensor', 'aten::mul.Scalar'}
        foo_exported = set(torch.jit.export_opnames(scripted_m))
        self.assertTrue(foo_expected.issubset(foo_exported))
示例#2
0
    def test_export_opnames_interface(self):
        # OneTwoModule is the interface type of M.sub; FooMod and BarMod are two
        # concrete implementations that are swapped in after scripting.
        @torch.jit.interface
        class OneTwoModule(nn.Module):
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                pass

            def two(self, x: torch.Tensor) -> torch.Tensor:
                pass

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                pass

        class FooMod(nn.Module):
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x + y

            def two(self, x: torch.Tensor) -> torch.Tensor:
                return 2 * x

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.one(self.two(x), x)

        class BarMod(nn.Module):
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x * y

            def two(self, x: torch.Tensor) -> torch.Tensor:
                return 2 / x

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.two(self.one(x, x))

        make_global(OneTwoModule)

        class M(nn.Module):
            sub: OneTwoModule

            def __init__(self):
                super().__init__()
                self.sub = BarMod()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.sub.forward(x)

        def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):
            return mod_list[0].forward(x) + mod_list[1].forward(x)

        # The lite interpreter does not support interface calls yet, so no op
        # names are expected for either submodule.
        # TODO: once interface calls are supported, assert instead that
        # {'aten::mul.Scalar', 'aten::mul.Tensor', 'aten::reciprocal'} is a
        # subset of the exported op names for the BarMod-backed module.
        scripted_m = torch.jit.script(M())
        self.assertEqual(len(torch.jit.export_opnames(scripted_m)), 0)

        scripted_m.sub = torch.jit.script(FooMod())
        self.assertEqual(len(torch.jit.export_opnames(scripted_m)), 0)
示例#3
0
    def test_class_with_multiple_methods(self):
        class PDTModelWithManyMethods:
            def test_list_to_dict(self, a):
                new_dictionary: Dict[float, bool] = {}
                for element in a:
                    new_dictionary[element] = True
                return new_dictionary

            def test_substring(self, a, b):
                return b in a

        make_global(PDTModelWithManyMethods)
        pdt_model = PDTModelWithManyMethods()

        # Each method is profiled with its own example inputs.
        list_inp: List[Tuple[Any, ...]] = [([1.2, 2.3],)]
        str_inp: List[Tuple[Any, ...]] = [("abc", "b")]
        scripted_pdt_model = torch.jit.script(
            PDTModelWithManyMethods,
            example_inputs={
                pdt_model.test_list_to_dict: list_inp,
                pdt_model.test_substring: str_inp,
            },
        )
        script_model = scripted_pdt_model()

        # The scripted instance must agree with the eager instance.
        floats = [1.1, 2.2, 3.3]
        self.assertEqual(script_model.test_list_to_dict(floats),
                         pdt_model.test_list_to_dict(floats))

        for haystack, needle in (("helloworld", "world"), ("helloworld", "def")):
            self.assertEqual(script_model.test_substring(haystack, needle),
                             pdt_model.test_substring(haystack, needle))
示例#4
0
    def test_any(self):
        def test_multiple_types(a):
            assert not isinstance(a, bool)
            return a

        def test_multiple_type_refinement(a):
            if isinstance(a, bool):
                return 1
            elif isinstance(a, int):
                return 1 + a
            elif isinstance(a, float):
                return 1 + int(a)
            else:
                return -1

        make_global(test_multiple_types, test_multiple_type_refinement)

        # The example inputs deliberately mix argument types; the scripted
        # function must then agree with eager Python on every probe value.
        scripted_fn = torch.jit._script_pdt(
            test_multiple_types,
            example_inputs=[(1,), ("abc",), (8.9,), ([3, 4, 5],)],
        )
        for probe in (10, "def", 7.89999, [10, 11, 14]):
            self.assertEqual(scripted_fn(probe), test_multiple_types(probe))

        # Same idea for a function that refines its argument with isinstance
        # (bool checked before int); its profile also includes a bool and a dict.
        scripted_fn = torch.jit._script_pdt(
            test_multiple_type_refinement,
            example_inputs=[
                (1,), ("abc",), (8.9,), ([3, 4, 5],), (True,), ({"a": True},),
            ],
        )
        refinement_probes = (
            10, "def", 7.89999, [10, 11, 14], False, {"abc": True, "def": False},
        )
        for probe in refinement_probes:
            self.assertEqual(scripted_fn(probe),
                             test_multiple_type_refinement(probe))
示例#5
0
    def test_class_methods(self):
        class PDTModel:
            def test_sum(self, a):
                return sum(a)

        make_global(PDTModel)
        pdt_model = PDTModel()

        # Profile the method on an int list, script the class, then compare
        # the scripted instance against the eager one.
        profile_inputs: List[Tuple[Any, ...]] = [([10, 20],)]
        scripted_cls = torch.jit._script_pdt(
            PDTModel, example_inputs={pdt_model.test_sum: profile_inputs})
        script_model = scripted_cls()
        values = [10, 20, 30]
        self.assertEqual(script_model.test_sum(values), pdt_model.test_sum(values))
示例#6
0
    def test_heterogenous_value_type_enum_error(self):
        # RED has an int value and GREEN a str value; scripting a function that
        # uses this enum is expected to fail with "Could not unify type list".
        class Color(Enum):
            RED = 1
            GREEN = "green"

        make_global(Color)

        def enum_comp(x: Color, y: Color) -> bool:
            return x == y

        expected_msg = "Could not unify type list"
        with self.assertRaisesRegex(RuntimeError, expected_msg):
            torch.jit.script(enum_comp)
示例#7
0
    def test_fx_tracing_with_typing(self):
        class FXModelOutput(NamedTuple):
            result: List[int]

        class FXModel(torch.nn.Module):
            def forward(self, a) -> FXModelOutput:
                result = FXModelOutput(result=a)
                return result

        make_global(FXModel, FXModelOutput)
        pdt_model = FXModel()

        # forward() is profiled with an int list and returns a NamedTuple.
        profile_inputs = {pdt_model: [([10, 20],)]}
        scripted_fn = torch.jit._script_pdt(pdt_model, example_inputs=profile_inputs)
        self.assertEqual(scripted_fn([20]), pdt_model([20]))
示例#8
0
    def test_multiple_class_with_same_method(self):
        class PDTModelOne:
            def test_find(self, a, b):
                return b in a.keys()

        class PDTModelTwo:
            def test_find(self, a, b):
                return b in a

        make_global(PDTModelOne, PDTModelTwo)
        pdt_model_one = PDTModelOne()
        pdt_model_two = PDTModelTwo()

        # The two classes share a method name but are scripted with different
        # example inputs: a float->bool dict for one, a str list for the other.
        dict_inp: List[Tuple[Any, ...]] = [({1.2: True, 2.3: False}, 1.2)]
        list_inp: List[Tuple[Any, ...]] = [(["abc", "b"], "c")]
        scripted_pdt_model_one = torch.jit.script(
            PDTModelOne, example_inputs={pdt_model_one.test_find: dict_inp})
        scripted_pdt_model_two = torch.jit.script(
            PDTModelTwo, example_inputs={pdt_model_two.test_find: list_inp})

        script_model_one = scripted_pdt_model_one()
        script_model_two = scripted_pdt_model_two()

        float_lookup = {1.1: True, 2.2: True, 3.3: False}
        self.assertEqual(script_model_one.test_find(float_lookup, 4.4),
                         pdt_model_one.test_find(float_lookup, 4.4))

        words = ["hello", "world"]
        self.assertEqual(script_model_two.test_find(words, "world"),
                         pdt_model_two.test_find(words, "world"))
示例#9
0
    def test_non_existent_enum_value(self):
        # Color has no PURPLE member, so compiling enum_const must fail and the
        # error should highlight the offending attribute access.
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        def enum_const(x: Color) -> bool:
            if x == Color.PURPLE:
                return True
            else:
                return False

        expected_msg = "has no attribute 'PURPLE'"
        expected_highlight = "Color.PURPLE"
        with self.assertRaisesRegexWithHighlight(RuntimeError, expected_msg,
                                                 expected_highlight):
            torch.jit.script(enum_const)
示例#10
0
    def test_heterogenous_value_type_enum_error(self):
        # RED has an int value and GREEN a str value; scripting a function that
        # uses this enum is expected to fail with "Could not unify type list".
        class Color(Enum):
            RED = 1
            GREEN = "green"

        make_global(Color)

        def enum_comp(x: Color, y: Color) -> bool:
            return x == y

        expected_msg = "Could not unify type list"
        # TODO: rewrite code so that the highlight is not empty.
        empty_highlight = ""
        with self.assertRaisesRegexWithHighlight(RuntimeError, expected_msg,
                                                 empty_highlight):
            torch.jit.script(enum_comp)
示例#11
0
    def test_enum_comp(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        @torch.jit.script
        def enum_comp(x: Color, y: Color) -> bool:
            return x == y

        # Enum equality must show up as aten::eq in the compiled graph.
        graph_text = str(enum_comp.graph)
        FileCheck().check("aten::eq").run(graph_text)

        cases = (
            (Color.RED, Color.RED, True),
            (Color.RED, Color.GREEN, False),
        )
        for lhs, rhs, expected in cases:
            self.assertEqual(enum_comp(lhs, rhs), expected)
示例#12
0
    def test_cast_overloads(self):
        """Check that TorchScript honors a class's __int__/__float__/__bool__/__str__.

        Also checks that a __bool__ returning a non-bool is rejected when the
        object is used as a condition.
        """
        @torch.jit.script
        class Foo(object):
            def __init__(self, val: float) -> None:
                self.val = val

            def __int__(self):
                return int(self.val)

            def __float__(self):
                return self.val

            def __bool__(self):
                return bool(self.val)

            def __str__(self):
                return str(self.val)

        make_global(Foo)  # see [local resolution in python]

        def test(foo: Foo) -> Tuple[int, float, bool]:
            if foo:
                pass
            return int(foo), float(foo), bool(foo)

        fn = torch.jit.script(test)
        # Feed the same kind of argument (a Foo instance) to both the scripted
        # and the eager function, so both sides go through Foo's cast
        # overloads rather than comparing against casts of a bare float.
        self.assertEqual(fn(Foo(0.5)), test(Foo(0.5)))
        self.assertEqual(fn(Foo(0.)), test(Foo(0.)))
        # str has slightly different formatting
        self.assertIn("0.5", str(Foo(0.5)))
        self.assertIn("0.", str(Foo(0.0)))

        @torch.jit.script
        class BadBool(object):
            def __init__(self):
                pass

            def __bool__(self):
                # Deliberately returns a tuple, not a bool.
                return (1, 2)

        with self.assertRaisesRegexWithHighlight(
                RuntimeError, "expected a bool expression for condition", ""):

            # Named distinctly so it does not shadow ``test`` defined above.
            @torch.jit.script
            def use_bad_bool():
                if BadBool():
                    print(1)
示例#13
0
    def test_nn_parameter_as_arg(self):
        class TestNNParameter(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.inp = torch.nn.Parameter(torch.ones(2, 3))

            def add_nn_parameter_with_int(self, x, y):
                return torch.add(x, y)

            def forward(self, y):
                return self.add_nn_parameter_with_int(self.inp, y)

        make_global(TestNNParameter)
        pdt_model = TestNNParameter()

        # forward() is profiled with an int; it forwards the nn.Parameter as
        # the first argument of the helper method.
        profile_inputs = {pdt_model: [(10,)]}
        scripted_fn = torch.jit._script_pdt(pdt_model, example_inputs=profile_inputs)
        self.assertEqual(scripted_fn(20), pdt_model(20))
示例#14
0
    def test_enum_return(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        @torch.jit.script
        def return_enum(cond: bool):
            if cond:
                return Color.RED
            else:
                return Color.GREEN

        # A scripted function can return enum members from either branch.
        for cond, expected in ((True, Color.RED), (False, Color.GREEN)):
            self.assertEqual(return_enum(cond), expected)
示例#15
0
    def test_pdt_dict(self):
        def test_dict(a):
            return a['foo']

        def test_dict_int_list(a):
            return a[1]

        make_global(test_dict, test_dict_int_list)

        # Profile with a str -> bool dict, then compare scripted vs eager.
        str_bool_inp = {'foo': True, 'bar': False}
        scripted_fn = torch.jit._script_pdt(test_dict, example_inputs=[(str_bool_inp,)])
        probe = {'foo': False, 'bar': True}
        self.assertEqual(scripted_fn(probe), test_dict(probe))

        # Profile with an int -> list-of-bool dict, then compare scripted vs eager.
        int_list_inp = {0: [True, False], 1: [False, True]}
        scripted_fn = torch.jit._script_pdt(test_dict_int_list,
                                            example_inputs=[(int_list_inp,)])
        probe = {0: [False, False], 1: [True, True]}
        self.assertEqual(scripted_fn(probe), test_dict_int_list(probe))
示例#16
0
    def test_pdt_list_and_tuple(self):
        def test_list_and_tuple(a):
            return sum(a)

        make_global(test_list_and_tuple)

        # Each (example, probe) pair scripts the function afresh from a single
        # example input, then checks a different value of the same container
        # and element type against eager Python.
        cases = [
            ([4.9, 8.9], [11.9, 7.6]),
            ([True, False, True], [True, True, True]),
            ([3, 4, 5], [1, 2, 3]),
            ((4.9, 8.9), (11.9, 7.6)),
            ((True, False, True), (True, True, True)),
            ((3, 4, 5), (1, 2, 3)),
        ]
        for example, probe in cases:
            scripted_fn = torch.jit.script(test_list_and_tuple,
                                           example_inputs=[(example,)])
            self.assertEqual(scripted_fn(probe), test_list_and_tuple(probe))
示例#17
0
    def test_string_enum_as_module_attribute(self):
        class Color(Enum):
            RED = "red"
            GREEN = "green"

        class TestModule(torch.nn.Module):
            def __init__(self, e: Color):
                super().__init__()
                self.e = e

            def forward(self):
                return (self.e.name, self.e.value)

        make_global(Color)

        # A str-valued enum stored as a module attribute: the scripted
        # forward() must report the same name and value as Python.
        scripted = torch.jit.script(TestModule(Color.RED))
        self.assertEqual(scripted(), (Color.RED.name, Color.RED.value))
示例#18
0
    def test_class_type_as_param(self):
        class FooTest(object):  # noqa: B903
            def __init__(self, x):
                self.attr = x

        make_global(FooTest)  # see [local resolution in python]

        @torch.jit.script
        def fn(foo: FooTest) -> torch.Tensor:
            return foo.attr

        @torch.jit.script
        def fn2(x):
            foo = FooTest(x)
            return fn(foo)

        # fn2 constructs a FooTest and passes it to another script function,
        # which reads the stored tensor back out unchanged.
        ones = torch.ones(1)
        self.assertEqual(fn2(ones), ones)
示例#19
0
    def test_enum_ivalue_type(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        @torch.jit.script
        def is_color_enum(x: Any):
            return isinstance(x, Color)

        # isinstance against an enum lowers to prim::isinstance with the
        # enum's qualified type, immediately followed by the return.
        (FileCheck()
            .check("prim::isinstance[types=[Enum<__torch__.jit.test_enum.Color>]]")
            .check_next("return")
            .run(str(is_color_enum.graph)))

        for value, expected in ((Color.RED, True), (Color.GREEN, True), (1, False)):
            self.assertEqual(is_color_enum(value), expected)
示例#20
0
    def test_enum_as_const(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        @torch.jit.script
        def enum_const(x: Color) -> bool:
            return x == Color.RED

        # Color.RED is embedded in the graph as a prim::Constant, compared
        # with aten::eq, then returned.
        (FileCheck()
            .check("prim::Constant[value=__torch__.jit.test_enum.Color.RED]")
            .check_next("aten::eq")
            .check_next("return")
            .run(str(enum_const.graph)))

        for member, expected in ((Color.RED, True), (Color.GREEN, False)):
            self.assertEqual(enum_const(member), expected)
示例#21
0
    def test_enum_value(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        @torch.jit.script
        def enum_value(x: Color) -> int:
            return x.value

        # Accessing .value lowers to prim::EnumValue right before the return.
        (FileCheck()
            .check("Color")
            .check_next("prim::EnumValue")
            .check_next("return")
            .run(str(enum_value.graph)))

        for member in (Color.RED, Color.GREEN):
            self.assertEqual(enum_value(member), member.value)
示例#22
0
    def test_py_class_to_ivalue_missing_attribute(self):
        """Passing an object lacking a class's attributes to a script fn must fail.

        A well-formed ``Foo`` is converted and summed correctly; a Tensor, which
        has no ``i`` attribute, is rejected with "missing attribute i".
        """
        class Foo(object):
            i: int
            f: float

            def __init__(self, i: int, f: float):
                self.i = i
                self.f = f

        make_global(Foo)  # see [local resolution in python]

        @torch.jit.script
        def test_fn(x: Foo) -> float:
            return x.i + x.f

        # Happy path: check the result rather than discarding it, so a broken
        # conversion that returns a wrong value is caught.
        self.assertEqual(test_fn(Foo(3, 4.0)), 7.0)

        with self.assertRaisesRegexWithHighlight(RuntimeError, 'missing attribute i', ""):
            test_fn(torch.rand(3, 4))
示例#23
0
    def test_union_with_enum(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        def fn(x: Union[str, Color]) -> str:
            return "foo"

        # Both members of the Union are accepted by the scripted function.
        for args in ((Color.RED,), ("red",)):
            self.checkScript(fn, args)

        scripted = torch.jit.script(fn)

        # An int is neither str nor Color and must be rejected at call time.
        expected_msg = ("Expected a member of"
                        r" Union\[__torch__.jit.test_union."
                        r"Color, str\] but instead found "
                        "type int")
        with self.assertRaisesRegex(RuntimeError, expected_msg):
            scripted(1)
示例#24
0
    def test_nn_module(self):
        """Script an nn.Module whose forward() is profiled with int, float and bool."""
        class TestPDTModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x) -> Any:
                if isinstance(x, int):
                    return x + 1
                elif isinstance(x, float):
                    return x - 1
                else:
                    return x

        make_global(TestPDTModel)
        pdt_model = TestPDTModel()
        inp: List[Tuple[Any, ...]] = [(20, ), (2.7, ), (False, ), ]
        scripted_pdt_model = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model: inp})
        self.assertEqual(scripted_pdt_model(50), pdt_model(50))
        self.assertEqual(scripted_pdt_model(1.8), pdt_model(1.8))
        # For a bool input only truthiness is checked. assertTrue's second
        # positional argument is a failure message, not an expected value, so
        # the eager result was never compared here anyway. In eager Python,
        # isinstance(True, int) is True, so pdt_model(True) takes the int branch
        # and returns 2.
        # TODO(review): this test does not establish whether the scripted
        # module matches that; decide the intended result and assert it.
        self.assertTrue(scripted_pdt_model(True))
示例#25
0
    def test_class_as_profiled_types(self):
        class UserDefinedClass:
            def fn(self, b) -> Any:
                assert b is not None
                if isinstance(b, int):
                    return b if b > 0 else -1
                elif isinstance(b, float):
                    return b if b > 0.0 else -1.0
                return 0

        def test_model(a, m):
            assert not isinstance(a, bool)
            return m.fn(a)

        make_global(UserDefinedClass, test_model)

        user_class = UserDefinedClass()
        # The first argument is profiled as both int and float; the second is
        # always the same user-defined class instance.
        profile_inputs = [(10, user_class), (10.9, user_class)]
        scripted_fn = torch.jit._script_pdt(test_model, example_inputs=profile_inputs)
        for value in (100, 1.9):
            self.assertEqual(scripted_fn(value, user_class), test_model(value, user_class))
示例#26
0
    def test_class_with_args_as_profiled_types(self):
        class ClassWithArgs:
            def __init__(self, a: bool):
                self.a = a

            def fn(self, b):
                if self.a:
                    return b
                else:
                    return -1

        def test_model_with_args(a, m):
            assert not isinstance(a, bool)
            return m.fn(a)

        make_global(ClassWithArgs, test_model_with_args)

        # Profile with an instance constructed with False, then call with a
        # fresh instance constructed with True.
        profiled_instance = ClassWithArgs(False)
        profile_inputs = [(10, profiled_instance), (10.9, profiled_instance)]
        scripted_fn = torch.jit._script_pdt(test_model_with_args, example_inputs=profile_inputs)
        self.assertEqual(scripted_fn(100, ClassWithArgs(True)),
                         test_model_with_args(100, ClassWithArgs(True)))
示例#27
0
    def test_nested_function_in_forward(self):
        class NestedFunctionInForward(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return self.fun(x) + 10

            def fun(self, x):
                if isinstance(x, bool):
                    return 0
                elif isinstance(x, int):
                    return x + 1
                return 0

        make_global(NestedFunctionInForward)
        pdt_model = NestedFunctionInForward()

        # Only forward() receives example inputs; it delegates to fun().
        inp: List[Tuple[Any, ...]] = [(-1,), (False,)]
        scripted_pdt_model = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model: inp})
        for value in (30, True):
            self.assertEqual(scripted_pdt_model(value), pdt_model(value))
示例#28
0
    def test_enum_value_types(self):
        class IntEnum(Enum):
            FOO = 1
            BAR = 2

        class FloatEnum(Enum):
            FOO = 1.2
            BAR = 2.3

        class StringEnum(Enum):
            FOO = "foo as in foo bar"
            BAR = "bar as in foo bar"

        make_global(IntEnum, FloatEnum, StringEnum)

        @torch.jit.script
        def supported_enum_types(a: IntEnum, b: FloatEnum, c: StringEnum):
            return (a.name, b.name, c.name)

        # All three enum types must appear, in order, in the compiled graph.
        checker = FileCheck()
        for enum_name in ("IntEnum", "FloatEnum", "StringEnum"):
            checker = checker.check(enum_name)
        checker.run(str(supported_enum_types.graph))

        # Tensor-valued enums are not supported and must fail to compile.
        class TensorEnum(Enum):
            FOO = torch.tensor(0)
            BAR = torch.tensor(1)

        make_global(TensorEnum)

        def unsupported_enum_types(a: TensorEnum):
            return a.name

        # TODO: rewrite code so that the highlight is not empty.
        expected_msg = "Cannot create Enum with value type 'Tensor'"
        with self.assertRaisesRegexWithHighlight(RuntimeError, expected_msg, ""):
            torch.jit.script(unsupported_enum_types)
示例#29
0
    def test_enum_comp_diff_classes(self):
        # Foo and Bar have identical member names and values but are distinct
        # enum classes, so Foo.ITEM1 must not compare equal to Bar.ITEM1.
        class Foo(Enum):
            ITEM1 = 1
            ITEM2 = 2

        class Bar(Enum):
            ITEM1 = 1
            ITEM2 = 2

        make_global(Foo, Bar)

        @torch.jit.script
        def enum_comp(x: Foo) -> bool:
            return x == Bar.ITEM1

        (FileCheck()
            .check("prim::Constant")
            .check_same("Bar.ITEM1")
            .check("aten::eq")
            .run(str(enum_comp.graph)))

        self.assertEqual(enum_comp(Foo.ITEM1), False)
示例#30
0
    def test_staticmethod(self):
        """
        Test static methods on class types.

        Covers a staticmethod that calls the class constructor, a staticmethod
        that calls another staticmethod, and a free function (checked via
        checkScript) that calls a staticmethod.
        """
        @torch.jit.script
        class ClassWithStaticMethod:
            def __init__(self, a: int, b: int):
                self.a: int = a
                self.b: int = b

            def get_a(self):
                return self.a

            def get_b(self):
                return self.b

            # Field-wise equality on a and b.
            def __eq__(self, other: 'ClassWithStaticMethod'):
                return self.a == other.a and self.b == other.b

            # staticmethod that calls constructor.
            # Builds a new instance from the fields of the first list element.
            @staticmethod
            def create(
                args: List['ClassWithStaticMethod']
            ) -> 'ClassWithStaticMethod':
                return ClassWithStaticMethod(args[0].a, args[0].b)

            # staticmethod that calls another staticmethod.
            @staticmethod
            def create_from(a: int, b: int) -> 'ClassWithStaticMethod':
                # Rebinds the int parameter ``a`` to a ClassWithStaticMethod
                # instance before delegating to create().
                a = ClassWithStaticMethod(a, b)
                return ClassWithStaticMethod.create([a])

        # Script function that calls staticmethod.
        def test_function(a: int, b: int) -> 'ClassWithStaticMethod':
            return ClassWithStaticMethod.create_from(a, b)

        # Registers the locally defined class globally; presumably needed so the
        # 'ClassWithStaticMethod' name resolves when test_function is scripted.
        make_global(ClassWithStaticMethod)

        self.checkScript(test_function, (1, 2))