Code example #1
    def test_tensor_type_not_determined_by_inputs(self):
        @torch.jit.script
        def scalar_type_input(x, y, z):
            return x + y + 4 + z.item()

        x = torch.tensor([2, 2])
        scalar_type_input(x, x, torch.tensor(1))
        scalar_type_input(x, x, torch.tensor(1))
        scalar_type_input(x, x, torch.tensor(1.0))
        g = torch.jit.last_executed_optimized_graph()

        # item & add should not get pulled into the fusion group;
        # we expect to see FusionGroup, (item / add), FusionGroup in the IR dump
        FileCheck().check("TensorExpr").check(
            "Scalar = aten::item").check_next("Tensor = aten::add").check(
                "TensorExpr").run(g)

        @torch.jit.script
        def non_const_dtype(x, y, cond: bool):
            dtype = torch.int16 if cond else torch.int32
            return (x + y + 3).sum(dtype=dtype)

        non_const_dtype(x, x, True)
        non_const_dtype(x, x, True)
        g = torch.jit.last_executed_optimized_graph()
        # because dtype is non-const, sum should not get pulled into the Fusion Group
        FileCheck().check("TensorExpr").check("TensorExpr").check_not(
            "aten::sum").run(g)
Code example #2
File: test_profiler.py Project: yuguo68/pytorch
    def test_specialized_types(self):
        @torch.jit.script
        def test_fuse(a, b):
            c = a * b
            d = c * b
            return d

        x = torch.tensor([.5])
        for _ in range(3):
            test_fuse(x, x)

        g = torch.jit.last_executed_optimized_graph()
        # Types should remain specialized for typecheck outputs & fusion outputs
        FileCheck().check("Double(").check_same("prim::TypeCheck").check_same("\n").check("Double").check_same("TensorExpr").run(g)

        # other outputs should not be specialized
        FileCheck().check("Tensor = prim::If").run(g)
Code example #3
    def test_not_fusing_scalar_ops(self):
        @torch.jit.script
        def foo(x: int, y: int):
            return x + y + 2 + 4 + 5 + 6

        foo(1, 2)
        foo(2, 3)
        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check_not("TensorExpr").run(g)
Code example #4
    def test_not_optimizing_property(self):
        @torch.jit.script
        def foo(x, y):
            return x + y + 1 + 2 + 3, x.size()

        x = torch.ones(1)
        foo(x, x)
        foo(x, x)
        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check("aten::size").run(g)
        x = torch.ones([2, 3, 5])
        self.assertEqual(foo(x, x), (x + x + 1 + 2 + 3, x.size()))
Code example #5
    def test_tensor_constant(self):
        def foo(a, b):
            return a + b + torch.tensor([2])

        x = torch.ones(1, requires_grad=False)
        foo_script = torch.jit.script(foo)
        foo_script(x, x)
        foo_script(x, x)

        self.assertEqual(foo_script(x, x), foo(x, x))
        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check_count("aten::add", 2, exactly=True).run(g)
Code example #6
File: test_profiler.py Project: yuguo68/pytorch
    def test_fallback_graph_not_specialized(self):
        @torch.jit.script
        def foo(a, b):
            c = a * b
            d = c * b
            e = d * b
            return d + e

        x = torch.ones(1)
        y = torch.ones(1)
        foo(x, y)
        foo(x, y)
        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check("CallFunction").check_next("Tensor = prim::TupleUnpack").run(g)
Code example #7
File: test_profiler.py Project: yuguo68/pytorch
    def test_aliasing_merge(self):
        @torch.jit.script
        def foo(a, b):
            c = a * b
            d = c * b
            d.add_(b)
            e = d * b
            return d + e

        x = torch.ones(1)
        y = torch.ones(1)
        foo(x, y)
        b = foo(x, y)
        g = torch.jit.last_executed_optimized_graph()
        self.assertEqual(len(list(g.findAllNodes("prim::TypeCheck"))), 2)
        FileCheck().check("TensorExpr").check("aten::add_").check("TensorExpr").run(g)
Code example #8
File: test_profiler.py Project: yuguo68/pytorch
    def test_local_fusion_strategy(self):
        @torch.jit.script
        def foo(x):
            return x + x + x

        torch.jit.set_fusion_strategy([("STATIC", 1)])
        for _ in range(3):
            foo(torch.rand([10]))

        torch.jit.set_fusion_strategy([("STATIC", 10)])

        for i in range(10):
            foo(torch.rand([i]))
            foo(torch.rand([i]))

        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check_count(":TensorExprGroup", 2, exactly=True).run(g)
Code example #9
File: test_profiler.py Project: yuguo68/pytorch
    def test_autograd_fallback_graph(self):
        @torch.jit.script
        def foo(a, b):
            c = a * b
            d = c * b
            e = d * b
            return d + e

        x = torch.ones(1, requires_grad=True)
        y = torch.ones(1, requires_grad=True)
        foo(x, y)
        b = foo(x, y)
        b.backward(torch.ones([1], dtype=torch.float), retain_graph=True)
        b.backward(torch.ones([1], dtype=torch.float))

        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check("fallback_function").check_next("CallFunction").run(g)
Code example #10
File: test_profiler.py Project: yuguo68/pytorch
    def test_use_not_profiled(self):
        def foo(t1, t2, t3, t4, t: float):
            h = t1 + t2 + t3 + t4
            if t > 0.5:
                # Putting a use of t1 in a never-executed conditional prevents
                # that use from ever being profiled
                return t1 + 1
            return h

        t = torch.rand(8, dtype=torch.float)

        foo_script = torch.jit.script(foo)
        for _ in range(torch._C._jit_get_num_profiled_runs() + 1):
            foo_script(t, t, t, t, 0.1)

        self.assertEqual(foo(t, t, t, t, 0.1), foo_script(t, t, t, t, 0.1))
        g = torch.jit.last_executed_optimized_graph()
        # all adds fused
        FileCheck().check("graph").check_not("aten::add").check("prim::If").run(g)
Code example #11
    def test_specialize_backward(self):
        def test_fuse(a, b):
            c = a * b
            d = c * b
            return d

        test_fuse.__disable_jit_function_caching__ = True

        scripted_f = torch.jit.script(test_fuse)
        x = torch.ones(1, requires_grad=True)
        y = torch.ones(1, requires_grad=True)
        scripted_f(x, y)
        b = scripted_f(x, y)
        warmup_backward(b)
        g = torch.jit.last_executed_optimized_graph()
        # Backward has an if node guarding specializations;
        # within the if node's true block there is only one if node,
        # which guards a TensorExpr group
        optimized_block = next(g.findNode("prim::If").blocks())
        if_nodes = list(optimized_block.findAllNodes("prim::If"))

        self.assertEqual(len(if_nodes), 1)
        FileCheck().check("Group[Subgraph").run(str(if_nodes[0]))
        # no broadcasts occurred, so the sum_to_size nodes have been specialized out
        self.assertIsNone(optimized_block.findNode("aten::_grad_sum_to_size"))

        broadcast_f = torch.jit.script(test_fuse)
        x = torch.ones([2, 2], requires_grad=True)
        y = torch.ones([1], requires_grad=True)
        broadcast_f(x, y)
        b = broadcast_f(x, y)
        b.backward(torch.ones([2, 2], dtype=torch.float), retain_graph=True)
        b.backward(torch.ones([2, 2], dtype=torch.float))
        # warmup_backward(b, torch.ones([2, 2], dtype=torch.float))
        g = torch.jit.last_executed_optimized_graph()
        optimized_block = next(g.findNode("prim::If").blocks())
        # broadcasts occurred, so we currently expect to see aten::_grad_sum_to_size
        self.assertIsNotNone(
            optimized_block.findNode("aten::_grad_sum_to_size"))