Example #1
    def test_start_with_acc_module_(self):
        """
           sin     relu     cos     sigmoid     tanh
        a ====> b =====> c ====> d ========> e =====> f

        We set sin, relu and cos as acc nodes and set min_acc_module_size to 2.
        The acc subgraph (sin, relu, cos) meets the minimum size, so we expect the
        module to be split into an acc submodule followed by a gpu submodule.
        """

        class TestModule(torch.nn.Module):
            def forward(self, a):
                b = torch.sin(a)
                c = torch.relu(b)
                d = torch.cos(c)
                e = torch.sigmoid(d)
                f = torch.tanh(e)
                return f

        mod = acc_tracer.trace(TestModule(), (torch.randn(2, 3),))

        # Set sin, cos and relu as acc nodes and split with settings
        class CustomOpSupport(op_support.OperatorSupport):
            _support_dict = {
                "acc_ops.sin": None,
                "acc_ops.cos": None,
                "acc_ops.relu": None,
            }

        # Create splitter setting and set min_acc_module_size to 2
        settings = splitter_base._SplitterSettingBase()
        settings.min_acc_module_size = 2
        splitter = TRTSplitter(
            mod,
            (torch.randn(2, 3),),
            op_support_with_support_dict(
                {
                    "acc_ops.sin": None,
                    "acc_ops.cos": None,
                    "acc_ops.relu": None,
                }
            ),
            settings,
        )

        def test_splitter(splitter):
            st_split = splitter()
            try:
                verify_split_model(st_split)
            except RuntimeError as err:
                self.assertEqual(
                    str(err), ERROR_MSG_NO_ACC_MODULE
                )
            modules = list(st_split.named_modules())
            # Main module and two submodules
            assert len(modules) == 3

            assert modules[1][0] == "_run_on_acc_0"
            assert modules[2][0] == "_run_on_gpu_1"

        test_splitter(splitter)
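
None of the examples show what the split result actually contains. Below is a small, hypothetical helper (summarize_split is not part of the test file above) that prints which call targets land in each _run_on_* submodule and checks that the split model still matches the traced module:

import torch

def summarize_split(traced_mod, split_mod, sample):
    # Print which call_function targets ended up in each split submodule.
    for name, submod in split_mod.named_modules():
        if name.startswith("_run_on_"):
            targets = [n.target for n in submod.graph.nodes if n.op == "call_function"]
            print(name, targets)
    # The split model should still compute the same result as the traced module.
    assert torch.allclose(traced_mod(sample), split_mod(sample))

Inside test_splitter above this could be called as summarize_split(mod, st_split, torch.randn(2, 3)).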
Example #2
    def __init__(
        self,
        module: torch.fx.GraphModule,
        sample_input: Tuple[torch.Tensor],
        operator_support: OperatorSupport = None,
        settings: splitter_base._SplitterSettingBase = None,
    ):
        if not operator_support:
            operator_support = TRTOperatorSupport()
        if not settings:
            settings = splitter_base._SplitterSettingBase()
        super().__init__(module, sample_input, operator_support, settings)
Example #3
    def __init__(self,
                 module: torch.fx.GraphModule,
                 sample_input: Tuple[torch.Tensor],
                 operator_support: op_support.OperatorSupport = None,
                 settings: splitter_base._SplitterSettingBase = None):
        if not operator_support:
            operator_support = op_support.OperatorSupport()

        if not settings:
            settings = splitter_base._SplitterSettingBase()
            settings.allow_non_tensor = True
            settings.skip_fusion = True

        super().__init__(module, sample_input, operator_support, settings)
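
The two constructors above fall back to a default operator support and default splitter settings when the caller passes None. A minimal usage sketch, where traced_mod and sample_input are placeholders for a module already traced with acc_tracer and its example input (they are not defined in the examples):

# Keep the default operator support, but raise the minimum acc submodule size.
settings = splitter_base._SplitterSettingBase()
settings.min_acc_module_size = 4  # only offload subgraphs with at least 4 supported nodes

splitter = TRTSplitter(traced_mod, (sample_input,), settings=settings)
split_mod = splitter()  # GraphModule with _run_on_acc_* / _run_on_gpu_* submodules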
Example #4
    def _trt_split(self, graph: fx.GraphModule,
                   input: Input) -> fx.GraphModule:
        splitter_settings = _SplitterSettingBase()
        splitter_settings.min_acc_module_size = self.min_acc_module_size

        splitter = TRTSplitter(
            graph,
            input,  # type: ignore[arg-type]
            self.operator_supported,
            settings=splitter_settings,
        )
        logger.info(f"""{splitter.node_support_preview.__name__}: {
            splitter.node_support_preview()
            }""")
        return splitter()
Example #5
    def __init__(
        self,
        module: torch.fx.GraphModule,
        sample_input: Tuple[torch.Tensor],
        operator_support: ops.OperatorSupportBase = None,
        settings: splitter_base._SplitterSettingBase = None,
    ):
        if not operator_support:
            operator_support = create_trt_operator_support()
        if not settings:
            settings = splitter_base._SplitterSettingBase()
        super().__init__(
            module,
            sample_input,
            operator_support,
            settings,
            non_acc_submodule_name="_run_on_gpu_",
        )
Example #6
    def test_split_non_tensor_edges_4(self):
        test_data = torch.randn(2, 3)

        module_nn = acc_tracer.trace(
            self.TestModule(),
            (test_data, ),
        )

        # Make 'a', 'c', 'd' and 'e' run on ACC, with a minimum ACC
        # subgraph size of 2
        settings = splitter_base._SplitterSettingBase()
        settings.min_acc_module_size = 2
        splitter = TRTSplitter(
            module_nn,
            (test_data, ),
            op_support_with_support_dict({
                "acc_ops.relu": None,
                "acc_ops.sigmoid": None,
                "acc_ops.cos": None,
                "acc_ops.add": None,
            }),
            settings,
        )

        def test_splitter(splitter):
            module_fx_split = splitter()
            verify_split_model(module_fx_split)

            self.assertEqual(
                {acc_ops.relu, acc_ops.cos},
                find_call_targets(module_fx_split._run_on_acc_0),
            )

            self.assertEqual(
                {acc_ops.size, acc_ops.getitem, acc_ops.add, acc_ops.sigmoid},
                find_call_targets(module_fx_split._run_on_cpu_1),
            )

            # Make sure we can compile to TorchScript
            module_jit = torch.jit.trace_module(module_fx_split,
                                                {"forward": test_data})
            self.assertTrue(
                torch.allclose(module_nn(test_data), module_jit(test_data)))

        test_splitter(splitter)
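
The find_call_targets helper used in the assertions above is defined elsewhere in the test file and is not shown here. A plausible minimal version, assuming it simply collects the targets of every call node in a GraphModule, would be:

from typing import Set

import torch.fx

def find_call_targets(gm: torch.fx.GraphModule) -> Set:
    # Collect the target of every call node (functions, methods, submodules).
    return {
        node.target
        for node in gm.graph.nodes
        if node.op in ("call_function", "call_method", "call_module")
    }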