Example #1
    def test_longer_chain(self):
        """
           sin     relu     cos     sigmoid     tanh
        a ====> b =====> c ====> d ========> e =====> f
        """

        class TestModule(torch.nn.Module):
            def forward(self, a):
                b = torch.sin(a)
                c = torch.relu(b)
                d = torch.cos(c)
                e = torch.sigmoid(d)
                f = torch.tanh(e)
                return f

        mod = acc_tracer.trace(TestModule(), torch.randn(2, 3))

        # Making relu and sigmoid execute on ACC
        splitter = TRTSplitter(
            mod,
            (torch.randn(2, 3),),
            op_support_with_support_dict(
                {
                    "acc_ops.relu": None,
                    "acc_ops.sigmoid": None,
                }
            ),
        )

        def test_splitter(splitter):
            st_split = splitter()
            try:
                verify_split_model(st_split)
            except RuntimeError as err:
                self.assertEqual(
                    str(err), ERROR_MSG_MULTI_ACC_MODULES
                )
            [arg] = find_inputs(st_split)

            # First subgraph calculates b = sin(a) on CPU
            [sin] = find_fun_calls(st_split._run_on_gpu_0, acc_ops.sin)
            self.assertEqual(arg.name, sin.kwargs["input"].name)

            # Second subgraph calculates c = relu(b) on ACC
            [relu] = find_fun_calls(st_split._run_on_acc_1, acc_ops.relu)
            self.assertEqual(sin.name, relu.kwargs["input"].name)

            # Third subgraph calculates d = cos(c) on CPU
            [cos] = find_fun_calls(st_split._run_on_gpu_2, acc_ops.cos)
            self.assertEqual(relu.name, cos.kwargs["input"].name)

            # Fourth subgraph calculates e = sigmoid(d) on ACC
            [sigmoid] = find_fun_calls(st_split._run_on_acc_3, acc_ops.sigmoid)
            self.assertEqual(cos.name, sigmoid.kwargs["input"].name)

            # Fifth subgraph calculates f = tanh(e) on CPU
            [tanh] = find_fun_calls(st_split._run_on_gpu_4, acc_ops.tanh)
            self.assertEqual(sigmoid.name, tanh.kwargs["input"].name)

        test_splitter(splitter)
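
These tests lean on a few helpers from the splitter test harness that are not shown in this listing: `op_support_with_support_dict`, `find_inputs`, `find_fun_calls`, `find_call_targets` and `find_output`. The sketch below reconstructs them purely from how they are used in the examples; it is an approximation, not the actual test-utility code.

# Hedged sketch of the helpers used throughout these examples. Reconstructed
# from usage; the real implementations in the test utilities may differ.
from typing import Any, Callable, Dict, List, Optional, Set

import torch.fx as fx
from torch.fx.passes import operator_support as op_support


def op_support_with_support_dict(support_dict: Dict[str, Optional[Any]]):
    # Treat every target named in `support_dict` (e.g. "acc_ops.relu") as
    # supported on the accelerator.
    return op_support.OperatorSupport(support_dict)


def find_inputs(module: fx.GraphModule) -> List[fx.Node]:
    # All placeholder nodes, i.e. the graph's inputs.
    return [n for n in module.graph.nodes if n.op == "placeholder"]


def find_fun_calls(module: fx.GraphModule, target: Callable) -> List[fx.Node]:
    # All call_function nodes whose target is `target` (e.g. acc_ops.sin).
    return [
        n for n in module.graph.nodes
        if n.op == "call_function" and n.target == target
    ]


def find_call_targets(module: fx.GraphModule) -> Set[Callable]:
    # The set of distinct call targets appearing in the graph.
    return {
        n.target
        for n in module.graph.nodes
        if n.op in ("call_function", "call_method", "call_module")
    }


def find_output(module: fx.GraphModule) -> fx.Node:
    # The (single) output node of the graph.
    return next(n for n in module.graph.nodes if n.op == "output")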
Example #2
    def test_nothing_to_split(self):
        class SimpleModule(torch.nn.Module):
            def forward(self, a):
                return a

        mod = acc_tracer.trace(SimpleModule(), torch.randn(2, 3))

        # Mark any operation as runnable on ACC
        class CustomOpSupport(op_support.OperatorSupportBase):
            def is_node_supported(self, submodules, node):
                return True

        splitter = TRTSplitter(
            mod, (torch.randn(2, 3),), CustomOpSupport()
        )

        def test_splitter(splitter):
            st_split = splitter()
            try:
                verify_split_model(st_split)
            except RuntimeError as err:
                self.assertEqual(
                    str(err), ERROR_MSG_NO_ACC_MODULE
                )
            self.assertEqual(splitter.module.__dict__.keys(), st_split.__dict__.keys())

        test_splitter(splitter)
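
`verify_split_model` and the `ERROR_MSG_*` constants come from the same test utilities. Based on how the examples use them, a simplified reading is sketched below; the message texts and the exact signature are assumptions, not the real code.

# Hedged sketch of verify_split_model. The real helper raises RuntimeError
# with ERROR_MSG_NO_ACC_MODULE when the split produced no ACC submodule and
# with ERROR_MSG_MULTI_ACC_MODULES when it produced more than one. The
# message texts below are placeholders, not the real constants.
import torch.fx as fx

ERROR_MSG_NO_ACC_MODULE = "placeholder: split produced no ACC submodule"
ERROR_MSG_MULTI_ACC_MODULES = "placeholder: split produced more than one ACC submodule"


def verify_split_model(split_module: fx.GraphModule, acc_prefix: str = "_run_on_acc") -> None:
    # Count child submodules whose name marks them as ACC-lowerable, e.g.
    # "_run_on_acc_0" (or "acc_0" when a custom prefix is passed, as in
    # Example #14 below).
    acc_count = sum(
        1 for name, _ in split_module.named_children() if name.startswith(acc_prefix)
    )
    if acc_count == 0:
        raise RuntimeError(ERROR_MSG_NO_ACC_MODULE)
    if acc_count > 1:
        raise RuntimeError(ERROR_MSG_MULTI_ACC_MODULES)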
Example #3
    def test_start_with_acc_module_(self):
        """
           sin     relu     cos     sigmoid     tanh
        a ====> b =====> c ====> d ========> e =====> f

        We set sin, relu and cos as acc nodes and also set min_acc_module_size to 2.
        Since those three nodes form an ACC subgraph of size 3, we expect the split
        to start with an ACC submodule (sin, relu, cos) followed by a GPU submodule
        (sigmoid, tanh).
        """

        class TestModule(torch.nn.Module):
            def forward(self, a):
                b = torch.sin(a)
                c = torch.relu(b)
                d = torch.cos(c)
                e = torch.sigmoid(d)
                f = torch.tanh(e)
                return f

        mod = acc_tracer.trace(TestModule(), torch.randn(2, 3))

        # Set sin, cos and relu as acc nodes and split with settings
        class CustomOpSupport(op_support.OperatorSupport):
            _support_dict = {
                "acc_ops.sin": None,
                "acc_ops.cos": None,
                "acc_ops.relu": None,
            }

        # Create splitter setting and set min_acc_module_size to 2
        settings = splitter_base._SplitterSettingBase()
        settings.min_acc_module_size = 2
        splitter = TRTSplitter(
            mod,
            (torch.randn(2, 3),),
            op_support_with_support_dict(
                {
                    "acc_ops.sin": None,
                    "acc_ops.cos": None,
                    "acc_ops.relu": None,
                }
            ),
            settings,
        )

        def test_splitter(splitter):
            st_split = splitter()
            try:
                verify_split_model(st_split)
            except RuntimeError as err:
                self.assertEqual(
                    str(err), ERROR_MSG_NO_ACC_MODULE
                )
            modules = list(st_split.named_modules())
            # Main module and two submodules
            assert len(modules) == 3

            assert modules[1][0] == "_run_on_acc_0"
            assert modules[2][0] == "_run_on_gpu_1"

        test_splitter(splitter)
Example #4
    def test_mod_with_getattr(self):
        """
        CPU subgraph should have get_attr for self.a while ACC subgraph
        should have get_attr for self.b.
        """

        class SimpleModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.randn(1, 1, 1, 1)
                self.b = torch.randn(1, 1, 1, 1)
                self.conv = torch.nn.Conv2d(1, 1, 1)
                self.linear = torch.nn.Linear(1, 1)

            def forward(self, x):
                x = x + self.a
                x = self.conv(x)
                return self.linear(x - self.b)

        mod = acc_tracer.trace(SimpleModule(), torch.randn(1, 1, 1, 1))
        mod.eval()

        splitter = TRTSplitter(
            mod,
            (torch.randn(1, 1, 1, 1),),
            op_support_with_support_dict(
                {
                    "acc_ops.linear": None,
                    "acc_ops.sub": None,
                }
            ),
        )

        def test_splitter(splitter):
            st_split = splitter()
            verify_split_model(st_split)
            # Should be "a", "conv.weight", "conv.bias".
            get_attr_nodes = [
                node.target
                for node in st_split._run_on_gpu_0.graph.nodes
                if node.op == "get_attr"
            ]
            assert len(get_attr_nodes) == 3 and "a" in get_attr_nodes

            # Should be "b", "conv.weight", "conv.bias".
            get_attr_nodes = [
                node.target
                for node in st_split._run_on_acc_1.graph.nodes
                if node.op == "get_attr"
            ]
            assert len(get_attr_nodes) == 3 and "b" in get_attr_nodes

        test_splitter(splitter)
Example #5
    def test_split_non_tensor_edges_3(self):
        test_data = torch.randn(2, 3)

        module_nn = acc_tracer.trace(
            self.TestModule(),
            (test_data, ),
        )

        # Making 'a', 'c', 'd' and 'e' run on ACC
        splitter = TRTSplitter(
            module_nn,
            (test_data, ),
            op_support_with_support_dict({
                "acc_ops.relu": None,
                "acc_ops.sigmoid": None,
                "acc_ops.cos": None,
                "acc_ops.add": None,
            }),
        )

        def test_splitter(splitter):
            module_fx_split = splitter()
            try:
                verify_split_model(module_fx_split)
            except RuntimeError as err:
                self.assertEqual(str(err), ERROR_MSG_MULTI_ACC_MODULES)

            self.assertEqual(
                {acc_ops.relu, acc_ops.cos},
                find_call_targets(module_fx_split._run_on_acc_0),
            )

            self.assertEqual(
                {acc_ops.size, acc_ops.getitem, acc_ops.add},
                find_call_targets(module_fx_split._run_on_cpu_1),
            )

            self.assertEqual(
                {acc_ops.sigmoid},
                find_call_targets(module_fx_split._run_on_acc_2),
            )

            # Make sure we can compile to TorchScript
            module_jit = torch.jit.trace_module(module_fx_split,
                                                {"forward": test_data})
            self.assertTrue(
                torch.allclose(module_nn(test_data), module_jit(test_data)))

        test_splitter(splitter)
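
This example and Example #8 below both split `self.TestModule`, which is defined elsewhere in the test class. From the operators the assertions expect (`relu`, `cos`, `size`, `getitem`, `add`, `sigmoid`) and the comment about 'a', 'c', 'd' and 'e', it plausibly looks like the following reconstruction; the actual module may differ in detail.

import torch

# Hedged reconstruction of the shared TestModule used by the
# test_split_non_tensor_edges_* tests. `b` is a plain Python int obtained via
# size()/getitem, so the edge from `b` into the add is a non-tensor edge,
# which is exactly what these tests exercise.
class TestModule(torch.nn.Module):
    def forward(self, x):
        a = torch.relu(x)
        b = a.size(0)          # traced as acc_ops.size + acc_ops.getitem
        c = torch.cos(a)
        d = c + b              # add with a non-tensor operand
        e = torch.sigmoid(d)
        return e

Under this reading, size, getitem and add are forced into one CPU subgraph because of the non-tensor value `b`, which matches the `_run_on_cpu_1` subgraph asserted above.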
Example #6
    def test_get_attr_into_output(self):
        """
        Here we verify the case when a get_attr node is consumed directly by the
        output. We don't expect any split to happen in this test; we just want to
        make sure that the splitter code doesn't break.
        """

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.randn(2, 3)

            def forward(self, x):
                return (x, self.a)

        # No need to put anything on ACC.
        class TestOperatorSupport:
            def is_node_supported(self, submodules, node):
                return False

        module_original = acc_tracer.trace(TestModule(), torch.randn(4, 5))

        splitter = TRTSplitter(
            module=module_original,
            sample_input=torch.randn(4, 5),
            operator_support=TestOperatorSupport(),
        )

        def test_splitter(splitter):
            module_split = splitter()
            try:
                verify_split_model(module_split)
            except RuntimeError as err:
                self.assertEqual(
                    str(err), ERROR_MSG_NO_ACC_MODULE
                )

            output = find_output(module_split)
            # Second argument of the output should be get_attr.
            self.assertEqual("get_attr", output.args[0][1].op)

            # Check if modules are equivalent.
            tensor = torch.randn(10, 20)
            result_original = module_original(tensor)
            result_split = module_split(tensor)
            self.assertTrue(torch.equal(result_original[0], result_split[0]))
            self.assertTrue(torch.equal(result_original[1], result_split[1]))

        test_splitter(splitter)
Example #7
    def _trt_split(self, graph: fx.GraphModule,
                   input: Input) -> fx.GraphModule:
        splitter_settings = _SplitterSettingBase()
        splitter_settings.min_acc_module_size = self.min_acc_module_size

        splitter = TRTSplitter(
            graph,
            input,  # type: ignore[arg-type]
            self.operator_supported,
            settings=splitter_settings,
        )
        logger.info(f"""{splitter.node_support_preview.__name__}: {
            splitter.node_support_preview()
            }""")
        return splitter()
Example #8
    def test_split_non_tensor_edges_4(self):
        test_data = torch.randn(2, 3)

        module_nn = acc_tracer.trace(
            self.TestModule(),
            (test_data, ),
        )

        # Making 'a', 'c', 'd' and 'e' run on ACC, with a limit on the ACC
        # subgraph size
        settings = splitter_base._SplitterSettingBase()
        settings.min_acc_module_size = 2
        splitter = TRTSplitter(
            module_nn,
            (test_data, ),
            op_support_with_support_dict({
                "acc_ops.relu": None,
                "acc_ops.sigmoid": None,
                "acc_ops.cos": None,
                "acc_ops.add": None,
            }),
            settings,
        )

        def test_splitter(splitter):
            module_fx_split = splitter()
            verify_split_model(module_fx_split)

            self.assertEqual(
                {acc_ops.relu, acc_ops.cos},
                find_call_targets(module_fx_split._run_on_acc_0),
            )

            self.assertEqual(
                {acc_ops.size, acc_ops.getitem, acc_ops.add, acc_ops.sigmoid},
                find_call_targets(module_fx_split._run_on_cpu_1),
            )

            # Make sure we can compile to TorchScript
            module_jit = torch.jit.trace_module(module_fx_split,
                                                {"forward": test_data})
            self.assertTrue(
                torch.allclose(module_nn(test_data), module_jit(test_data)))

        test_splitter(splitter)
Example #9
    def test_demo(self):
        """
          ==> b ==>
        //         \\
       a             d
        \\         //
          ==> c ==>
        """

        class SimpleModule(torch.nn.Module):
            def forward(self, a):
                b = torch.sin(a)
                c = torch.cos(a)
                d = b + c
                return d

        mod = acc_tracer.trace(SimpleModule(), torch.randn(2, 3))

        # Making b and c run on ACC
        splitter = TRTSplitter(
            mod,
            (torch.randn(2, 3),),
            op_support_with_support_dict(
                {
                    "acc_ops.sin": None,
                    "acc_ops.cos": None,
                }
            ),
        )

        st_split = splitter()

        [arg] = find_inputs(st_split)

        # First subgraph calculates b = sin(a) and c = cos(a) on ACC
        [sin] = find_fun_calls(st_split._run_on_acc_0, acc_ops.sin)
        self.assertEqual(arg.name, sin.kwargs["input"].name)

        [cos] = find_fun_calls(st_split._run_on_acc_0, acc_ops.cos)
        self.assertEqual(arg.name, cos.kwargs["input"].name)

        # Second subgraph calculates d = b + c on CPU
        [add] = find_fun_calls(st_split._run_on_gpu_1, acc_ops.add)
        self.assertEqual(sin.name, add.kwargs["input"].name)
        self.assertEqual(cos.name, add.kwargs["other"].name)
Example #10
    def test_get_attr_into_starter_node(self):
        """
        Here we verify the case when starter nodes depend only on get_attr nodes.
        We don't expect any split to happen in this test; we just want to make sure
        that the splitter code doesn't break.
        """

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.randn(2, 3)

            def forward(self):
                m = self.a + self.a
                o = m + m
                return o

        # No need to put anything on ACC.
        class TestOperatorSupport:
            def is_node_supported(self, submodules, node):
                return False

        module_original = acc_tracer.trace(TestModule(), torch.randn(2, 3))

        splitter = TRTSplitter(
            module=module_original,
            sample_input=torch.randn(2, 3),
            operator_support=TestOperatorSupport(),
        )

        def test_splitter(splitter):
            module_split = splitter()
            try:
                verify_split_model(module_split)
            except RuntimeError as err:
                self.assertEqual(
                    str(err), ERROR_MSG_NO_ACC_MODULE
                )

            # Check if modules are equivalent.
            result_original = module_original()
            result_split = module_split()
            self.assertTrue(torch.equal(result_original, result_split))

        test_splitter(splitter)
Example #11
    def test_split_complex_graph_2(self):
        module_nn = self.TestModule()
        module = acc_tracer.trace(module_nn, (torch.randn(2, 3),))

        # Making 'c', 'd' and 'e' run on ACC
        splitter = TRTSplitter(
            module,
            (torch.randn(2, 3),),
            op_support_with_support_dict(
                {
                    "acc_ops.cos": None,
                    "acc_ops.relu": None,
                    "acc_ops.add": None,
                }
            ),
        )

        def test_splitter(splitter):
            module_fx_split = splitter()
            verify_split_model(module_fx_split)

            [arg] = find_inputs(module)

            # First subgraph calculates b = sin(a) on CPU
            [sin] = find_fun_calls(module_fx_split._run_on_gpu_0, acc_ops.sin)
            self.assertEqual(arg.name, sin.kwargs["input"].name)

            # Second subgraph calculates c = relu(a), d = cos(a) and e = b + c on ACC
            [relu] = find_fun_calls(module_fx_split._run_on_acc_1, acc_ops.relu)
            self.assertEqual(arg.name, relu.kwargs["input"].name)

            [cos] = find_fun_calls(module_fx_split._run_on_acc_1, acc_ops.cos)
            self.assertEqual(arg.name, cos.kwargs["input"].name)

            [add] = find_fun_calls(module_fx_split._run_on_acc_1, acc_ops.add)
            self.assertEqual(sin.name, add.kwargs["input"].name)
            self.assertEqual(relu.name, add.kwargs["other"].name)

            # Third subgraph calculates f = e - d on CPU
            [sub] = find_fun_calls(module_fx_split._run_on_gpu_2, acc_ops.sub)
            self.assertEqual(add.name, sub.kwargs["input"].name)
            self.assertEqual(cos.name, sub.kwargs["other"].name)

        test_splitter(splitter)
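
Here too `self.TestModule` is defined outside this listing. Working backwards from the assertions (b = sin(a), c = relu(a), d = cos(a), e = b + c, f = e - d), a plausible reconstruction is shown below; the actual test module may differ.

import torch

# Hedged reconstruction of the TestModule used by the
# test_split_complex_graph_* tests, inferred from the assertions above.
class TestModule(torch.nn.Module):
    def forward(self, a):
        b = torch.sin(a)
        c = torch.relu(a)
        d = torch.cos(a)
        e = b + c
        f = e - d
        return f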
Example #12
    def test_multi_output(self):
        class MultiOutputModule(torch.nn.Module):
            def forward(self, x):
                res, ind = torch.topk(x, 3)
                return torch.sigmoid(res), ind

        mod = acc_tracer.trace(MultiOutputModule(), torch.randn(2, 3))

        # Mark any operation as runnable on ACC
        class CustomOpSupport(op_support.OperatorSupportBase):
            def is_node_supported(self, submodules, node):
                return True

        splitter = TRTSplitter(
            mod, (torch.randn(2, 3),), CustomOpSupport()
        )

        def test_splitter(splitter):
            st_split = splitter()
            verify_split_model(st_split)
            [arg] = find_inputs(st_split)

            # There is only one subgraph that executes topk and sigmoid on ACC
            [topk] = find_fun_calls(st_split._run_on_acc_0, acc_ops.topk)
            self.assertEqual(arg.name, topk.kwargs["input"].name)
            self.assertEqual(3, topk.kwargs["k"])

            [topk_res1, topk_res2] = find_fun_calls(
                st_split._run_on_acc_0, acc_ops.getitem
            )

            [sigmoid] = find_fun_calls(st_split._run_on_acc_0, acc_ops.sigmoid)
            self.assertIn(
                sigmoid.kwargs["input"].name, {topk_res1.name, topk_res2.name}
            )

            # The ACC submodule returns a tuple
            output = find_output(st_split._run_on_acc_0)
            self.assertLess(
                {output.args[0][0].name, output.args[0][1].name},
                {topk_res1.name, topk_res2.name, sigmoid.name},
            )

        test_splitter(splitter)
Example #13
        return x

inputs = [torch.randn(1, 10)]
model = Model().eval()

# acc_tracer is a custom fx tracer that maps nodes whose targets are PyTorch operators
# to acc ops.
traced = acc_tracer.trace(model, inputs)

# Splitter will split the model into several submodules. The names of the submodules
# will be either `run_on_acc_{}` or `run_on_gpu_{}`. Submodules named `run_on_acc_{}`
# can be fully lowered to TensorRT via fx2trt, while submodules named `run_on_gpu_{}`
# have unsupported ops and can't be lowered by fx2trt. We can still run `run_on_gpu_{}`
# submodules on GPU if their ops have CUDA implementations; the naming is a bit
# confusing and we'll improve it.
splitter = TRTSplitter(traced, inputs)

# The preview functionality allows us to see which ops are supported and which are
# unsupported. We can optionally dump the dot graph, which will color supported and
# unsupported ops differently.
splitter.node_support_preview(dump_graph=False)
"""
Supported node types in the model:
acc_ops.linear: ((), {'input': torch.float32, 'weight': torch.float32, 'bias': torch.float32})
acc_ops.relu: ((), {'input': torch.float32})

Unsupported node types in the model:
acc_ops.linalg_norm: ((), {'input': torch.float32})
"""

# Split.
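
The walkthrough stops right at the split step. A plausible continuation, following the pattern of the test examples above (the inspection code is illustrative, not taken from the original source), is:

# Perform the split. The result is a GraphModule whose children are the
# `_run_on_acc_{}` / `_run_on_gpu_{}` submodules described above.
split_mod = splitter()

# Inspect how the model was partitioned.
print(split_mod.graph)
for name, submod in split_mod.named_children():
    print(name)
    print(submod.graph)

# `_run_on_acc_{}` submodules are the candidates for lowering to TensorRT via
# fx2trt; `_run_on_gpu_{}` submodules keep running as regular PyTorch modules.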
Example #14
    def test_extend_acc_subgraph_after_split(self):
        class TestModule(torch.nn.Module):
            r"""     a (input)
                     |
                     b
                    / \
                   c   d
                    \ /
                     e
                    / \
                   |   (g1, g2, g3, g4)
                    \ / |
                     f  |
                      \ |
                       h

            c and f are not runnable on acc while all other nodes are supported by acc.
            g1, g2, g3 and g4 should be in a fusion group, let's call it g.

            After the split we have 2 cpu subgraphs, (c) and (f), and 3 acc subgraphs, (b, d), (e, g) and (h).
            We expect 3 acc subgraphs (b), (d, e, g) and (h) after extending the second acc subgraph,
            and we expect the acc subgraphs to stay the same after extending the third acc subgraph
            because of the unbreakable fusion group.
            """
            def forward(self, a: torch.Tensor):
                b = a + a
                c = b - b
                d = b + b
                e = c + d

                # These four nodes should be in a fusion group
                g1 = e.size()
                g2 = g1[0]
                g3 = e + g2
                g4 = g3 + g2

                f = e - g3
                h = f + g4
                return h

        a = torch.randn(2)
        mod = acc_tracer.trace(TestModule(), (a, ))

        # Allow all nodes except subtract to run on the accelerator
        class CustomOpSupport(op_support.OperatorSupportBase):
            def is_node_supported(self, submodules, node):
                return op_support.get_node_target(submodules,
                                                  node) != "acc_ops.sub"

        splitter = TRTSplitter(mod, (a, ), CustomOpSupport())

        def test_splitter(splitter):
            # Manually tag nodes first in case the split algorithm changes in the future
            nodes = list(splitter.module.graph.nodes)
            # b and d
            nodes[1].tag = "acc_0"
            nodes[3].tag = "acc_0"
            # c
            nodes[2].tag = "cpu_1"
            # e and g
            nodes[4].tag = "acc_2"
            nodes[5].tag = "acc_2"
            nodes[6].tag = "acc_2"
            nodes[7].tag = "acc_2"
            nodes[8].tag = "acc_2"
            # f
            nodes[9].tag = "cpu_3"
            # h
            nodes[10].tag = "acc_4"

            splitter.tags = ["acc_0", "cpu_1", "acc_2", "cpu_3", "acc_4"]
            split_module = splitter.split()
            try:
                verify_split_model(split_module, "acc_")
            except RuntimeError as err:
                self.assertEqual(str(err), ERROR_MSG_MULTI_ACC_MODULES)
            try:
                verify_split_model(split_module)
            except RuntimeError as err:
                self.assertEqual(str(err), ERROR_MSG_NO_ACC_MODULE)

            module_names = [name for name, _ in split_module.named_modules()]
            # Main module, 2 cpu submodules and 3 acc submodules
            assert len(module_names) == 6

            # 1 Placeholder, 2 Adds and 1 Output
            assert len(split_module.acc_0.graph.nodes) == 4
            # 2 Placeholders, 3 Adds, 1 Size, 1 GetItem and 1 Output
            assert len(split_module.acc_2.graph.nodes) == 8

            # Extend the second acc subgraph
            splitter.extend_acc_subgraph("acc_2")
            extend_module = splitter.split()
            try:
                verify_split_model(extend_module, "acc_")
            except RuntimeError as err:
                self.assertEqual(str(err), ERROR_MSG_MULTI_ACC_MODULES)

            # 1 Placeholder, 1 Add and 1 Output
            assert len(extend_module.acc_0.graph.nodes) == 3
            # 2 Placeholders, 4 Adds, 1 Size, 1 GetItem and 1 Output
            assert len(extend_module.acc_2.graph.nodes) == 9

            # Extend the third acc subgraph
            splitter.extend_acc_subgraph("acc_4")
            extend_module = splitter.split()
            try:
                verify_split_model(extend_module, "acc_")
            except RuntimeError as err:
                self.assertEqual(str(err), ERROR_MSG_MULTI_ACC_MODULES)

            assert len(extend_module.acc_2.graph.nodes) == 9
            # 2 Placeholders, 1 Add and 1 Output
            assert len(extend_module.acc_4.graph.nodes) == 4

        test_splitter(splitter)
Example #15
    def test_nested_modules(self):
        """
                x
             //   \\
            //     \\
        relu(x)    sin(x)
            \\     //
             \\   //
         relu(x) + sin(x)
        """
        class ReluModule(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        class SinModule(torch.nn.Module):
            def forward(self, x):
                return torch.sin(x)

        class TestModule3(torch.nn.Module):
            def __init__(self, relu_module, sin_module):
                super().__init__()
                self.relu_module = relu_module
                self.sin_module = sin_module

            def forward(self, x):
                return self.relu_module(x) + self.sin_module(x)

        mod = acc_tracer.trace(TestModule3(ReluModule(), SinModule()),
                               torch.randn(2, 3))

        # Making sin(x) run on ACC
        splitter = TRTSplitter(
            mod,
            (torch.randn(2, 3), ),
            op_support_with_support_dict({
                "acc_ops.sin": None,
            }),
        )

        def test_splitter(splitter):
            st_split = splitter()
            verify_split_model(st_split)
            [arg] = find_inputs(st_split)

            # First subgraph calculates relu(x) on CPU
            [relu] = find_fun_calls(st_split._run_on_cpu_0, acc_ops.relu)
            self.assertEqual(arg.name, relu.kwargs["input"].name)

            # Second subgraph calculates sin(x) on ACC
            [sin] = find_fun_calls(st_split._run_on_acc_1, acc_ops.sin)
            self.assertEqual(arg.name, sin.kwargs["input"].name)

            # Third subgraph calculates sum on CPU
            [add] = find_fun_calls(st_split._run_on_cpu_2, acc_ops.add)
            self.assertEqual(relu.name, add.kwargs["input"].name)
            self.assertEqual(sin.name, add.kwargs["other"].name)

            # Check that the split module gives the same results as the original
            tensor = torch.randn(5)
            self.assertTrue(torch.equal(mod(tensor), st_split(tensor)))

        test_splitter(splitter)