Пример #1
0
        def checkCmath(func_name, funcs_template=funcs_template):
            """Compare a templated function between plain Python and TorchScript.

            ``funcs_template`` is formatted with ``func_or_const=func_name``. The
            resulting source is executed as Python and also compiled with
            ``torch.jit.CompilationUnit``; both must define ``func``. Each value in
            ``complex_vals`` is passed to both versions. Inputs on which the Python
            version raised are skipped. Any other disagreement is reported through
            ``self.assertEqual`` with a descriptive message.
            """
            def capture(fn, arg):
                # Return fn(arg), or the exception it raised, so either outcome can be compared.
                try:
                    return fn(arg)
                except Exception as err:
                    return err

            source = funcs_template.format(func_or_const=func_name)
            py_namespace = {}
            execWrapper(source, globals(), py_namespace)
            unit = torch.jit.CompilationUnit(source)
            scripted = unit.func
            eager = py_namespace['func']

            for value in complex_vals:
                eager_out = capture(eager, value)
                script_out = capture(scripted, value)
                # Inputs that plain Python rejects are not compared.
                if eager_out != script_out and not isinstance(eager_out, Exception):
                    msg = f"Failed on {func_name} with input {value}. Python: {eager_out}, Script: {script_out}"
                    self.assertEqual(eager_out, script_out, msg=msg)
Пример #2
0
    def test_ops_bound_in_functional(self):
        """Check how several ops behave when scripted.

        - A zero-argument ``torch.<op>()`` call for each listed op must be valid
          plain Python, but compiling it with ``torch.jit.CompilationUnit`` must
          fail with "Unknown builtin op".
        - The scripted ``unique_consecutive`` call (with ``return_counts=True``)
          is expected to produce a result that differs from eager mode.
        - ``tensordot`` with list-pair ``dims`` runs eagerly but is expected to
          fail scripting with an "Argument dims_self" error.
        """
        # Explicit one-element tuple (previously written as a bare trailing comma).
        ops_bound_in_functional = ("unique",)
        funcs_template = dedent('''
        def func():
            return torch.{op}()
        ''')
        for op in ops_bound_in_functional:
            funcs_str = funcs_template.format(op=op)
            scope = {}
            # The source must still be valid Python that defines ``func``.
            execWrapper(funcs_str, globals(), scope)
            self.assertIn('func', scope)
            with self.assertRaisesRegex(Exception, "Unknown builtin op"):
                torch.jit.CompilationUnit(funcs_str)

        def unique_consec():
            x = torch.tensor([1])
            return torch.unique_consecutive(x,
                                            return_inverse=False,
                                            return_counts=True,
                                            dim=0)

        self.assertNotEqual(unique_consec(), torch.jit.script(unique_consec)())

        def tensordot():
            a = torch.arange(60.).reshape(3, 4, 5)
            b = torch.arange(24.).reshape(4, 3, 2)
            torch.tensordot(a, b, dims=([1, 0], [0, 1]))

        # Eager execution must succeed; only scripting is expected to fail.
        tensordot()
        with self.assertRaisesRegex(Exception, "Argument dims_self"):
            torch.jit.script(tensordot)
Пример #3
0
        def checkCmath(func_name, funcs_template=funcs_template):
            """Compare a templated function between plain Python and TorchScript.

            ``funcs_template`` is formatted with ``func_or_const=func_name``. The
            resulting source is executed as Python and also compiled with
            ``torch.jit.CompilationUnit``; both must define ``func``.

            For ``isinf``, ``isnan`` and ``isfinite``, the inputs are every complex
            number whose real and imaginary parts come from ``vals`` extended with
            +inf, nan and -inf. Every other function uses ``complex_vals``.

            Inputs on which the Python version raised are skipped. Any other
            disagreement is reported through ``self.assertEqual``.
            """
            def capture(fn, arg):
                # Return fn(arg), or the exception it raised, so either outcome can be compared.
                try:
                    return fn(arg)
                except Exception as err:
                    return err

            source = funcs_template.format(func_or_const=func_name)
            py_namespace = {}
            execWrapper(source, globals(), py_namespace)
            unit = torch.jit.CompilationUnit(source)
            scripted = unit.func
            eager = py_namespace['func']

            if func_name in ('isinf', 'isnan', 'isfinite'):
                # Classification functions are also exercised on non-finite components.
                extended = vals + [float('inf'), float('nan'), float('-inf')]
                inputs = tuple(complex(real, imag) for real, imag in product(extended, extended))
            else:
                inputs = complex_vals

            for value in inputs:
                eager_out = capture(eager, value)
                script_out = capture(scripted, value)
                # Inputs that plain Python rejects are not compared.
                if eager_out != script_out and not isinstance(eager_out, Exception):
                    msg = f"Failed on {func_name} with input {value}. Python: {eager_out}, Script: {script_out}"
                    self.assertEqual(eager_out, script_out, msg=msg)
Пример #4
0
    def test_ops_bound_in_functional(self):
        """Check how several ops behave when scripted.

        - A zero-argument ``torch.<op>()`` call for each listed op must be valid
          plain Python, but compiling it with ``torch.jit.CompilationUnit`` must
          fail with "Unknown builtin op".
        - ``cdist`` with a string ``compute_mode`` and ``norm`` with ``p="fro"``
          run eagerly but are expected to fail scripting with
          "Expected a value of type".
        - The scripted ``unique_consecutive`` call (with ``return_counts=True``)
          is expected to produce a result that differs from eager mode.
        - ``tensordot`` with list-pair ``dims`` runs eagerly but is expected to
          fail scripting with an "Argument dims_self" error.
        """
        ops_bound_in_functional = ("lu_unpack", "unique", "lu")

        funcs_template = dedent('''
        def func():
            return torch.{op}()
        ''')
        for op in ops_bound_in_functional:
            funcs_str = funcs_template.format(op=op)
            scope = {}
            # The source must still be valid Python that defines ``func``.
            execWrapper(funcs_str, globals(), scope)
            self.assertIn('func', scope)
            with self.assertRaisesRegex(Exception, "Unknown builtin op"):
                torch.jit.CompilationUnit(funcs_str)

        def fn():
            a = torch.tensor([[0.9041, 0.0196], [-0.3108, -2.4423],
                              [-0.4821, 1.059]])
            b = torch.tensor([[-2.1763, -0.4713], [-0.6986, 1.3702]])
            return torch.cdist(a, b, compute_mode="use_mm_for_euclid_dist")

        # Eager execution must succeed; only scripting is expected to fail.
        fn()
        with self.assertRaisesRegex(Exception, "Expected a value of type"):
            torch.jit.script(fn)

        def norm():
            c = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float)
            return torch.norm(c, p="fro")

        norm()
        with self.assertRaisesRegex(Exception, "Expected a value of type"):
            torch.jit.script(norm)

        def unique_consec():
            x = torch.tensor([1])
            return torch.unique_consecutive(x,
                                            return_inverse=False,
                                            return_counts=True,
                                            dim=0)

        self.assertNotEqual(unique_consec(), torch.jit.script(unique_consec)())

        def tensordot():
            a = torch.arange(60.).reshape(3, 4, 5)
            b = torch.arange(24.).reshape(4, 3, 2)
            torch.tensordot(a, b, dims=([1, 0], [0, 1]))

        tensordot()
        with self.assertRaisesRegex(Exception, "Argument dims_self"):
            torch.jit.script(tensordot)
Пример #5
0
 def test_index_ellipses(self):
     """Randomised check that indexing with ``...`` agrees between TorchScript and Python.

     Each of 100 trials does the following:
     - Picks four index items from ``:``, ``1`` and ``None``.
     - Overwrites one random position with ``...``.
     - Builds a function that returns ``x[<indices>].shape`` for
       ``x = torch.ones(10, 9, 8, 7, 6)``.
     - Compares the TorchScript-compiled result with the plain-Python result.
     """
     choices = [":", 1, None]
     template = """
     def f():
         x = torch.ones(10, 9, 8, 7, 6)
         return x{indices}.shape
     """
     for _ in range(100):
         picks = [random.choice(choices) for _ in range(4)]
         picks[random.randint(0, len(picks) - 1)] = "..."
         # str(list) quotes the string items; strip the quotes to get valid index syntax.
         source = dedent(template.format(indices=picks)).replace("'", "")
         namespace = {}
         execWrapper(source, globals(), namespace)
         unit = torch.jit.CompilationUnit(source)
         scripted_shape = unit.f()
         eager_shape = namespace['f']()
         self.assertEqual(scripted_shape, eager_shape)
    def test_arange_shape(self):
        """Run shape analysis on scripted ``torch.arange`` calls over a set of argument tuples.

        For each tuple, this builds ``torch.arange(<args>)`` as source, checks it
        is valid Python, and compiles it with ``torch.jit.CompilationUnit``. It
        then passes the actual output size and the compiled graph to
        ``self.checkShapeAnalysis`` with ``assert_propagation=True`` and
        ``constant_prop=False``.
        """
        # no opinfo for tensor constructors
        inps = [
            (10, ),
            (10, 10),
            (0, 10),
            (0, 1000),
            (1, -1, -1),
            (1, 0, -1),
            (1, 2, 1),
            (0.6, 0.89, 0.1),
            (1, 10, 0.3),
            (1, 10, 4),
            (0.6, 0.7, 0.8),
            # (True,),  TODO: https://github.com/pytorch/pytorch/issues/63405
            # (False,), TODO: https://github.com/pytorch/pytorch/issues/63405
            (0, 5),
            (0, 5, 2),
            (0, 5 + 1e-6),
            (0, 5 - 1e-6),
            (10, -1 + 1e-6, -1),
            (10, -1, -1),
            (10, -1 - 1e-6, -1),
        ]

        # The template does not depend on the input, so build it once.
        funcs_template = dedent('''
        def func():
            return torch.arange({args})
        ''')

        for inp in inps:
            inp_s = str(inp)[1:-1]  # remove tuple parens
            funcs_str = funcs_template.format(args=inp_s)
            scope = {}
            # Also validates that the generated source is legal plain Python.
            execWrapper(funcs_str, globals(), scope)
            cu = torch.jit.CompilationUnit(funcs_str)
            self.checkShapeAnalysis(list(cu.func().size()),
                                    cu.func.graph,
                                    assert_propagation=True,
                                    constant_prop=False)
Пример #7
0
        def checkMath(func_name):
            """Compare ``cmath.<func_name>`` between plain Python and TorchScript.

            A one-argument function ``func(a: complex)`` that returns
            ``cmath.<func_name>(a)`` is executed as Python and also compiled with
            ``torch.jit.CompilationUnit``. Each value in ``complex_vals`` is passed
            to both versions. Inputs on which the Python version raised are
            skipped. Any other disagreement is reported through
            ``self.assertEqual``.
            """
            def capture(fn, arg):
                # Return fn(arg), or the exception it raised, so either outcome can be compared.
                try:
                    return fn(arg)
                except Exception as err:
                    return err

            source = dedent('''
            def func(a: complex):
                return cmath.{func}(a)
            ''').format(func=func_name)
            py_namespace = {}
            execWrapper(source, globals(), py_namespace)
            unit = torch.jit.CompilationUnit(source)
            scripted = unit.func
            eager = py_namespace['func']

            for value in complex_vals:
                eager_out = capture(eager, value)
                script_out = capture(scripted, value)
                # Inputs that plain Python rejects are not compared.
                if eager_out != script_out and not isinstance(eager_out, Exception):
                    msg = f"Failed on {func_name} with input {value}. Python: {eager_out}, Script: {script_out}"
                    self.assertEqual(eager_out, script_out, msg=msg)