def test_start_with_acc_module_(self):
    """
       sin     relu    cos      sigmoid    tanh
    a ====> b =====> c ====> d ========> e =====> f

    We set sin, relu and cos as acc nodes and set min_acc_module_size to 2.
    The acc subgraph (sin, relu, cos) meets the minimum size, so we expect
    the graph to split into an acc submodule followed by a GPU submodule.
    """

    class TestModule(torch.nn.Module):
        def forward(self, a):
            b = torch.sin(a)
            c = torch.relu(b)
            d = torch.cos(c)
            e = torch.sigmoid(d)
            f = torch.tanh(e)
            return f

    mod = acc_tracer.trace(TestModule(), torch.randn(2, 3))

    # Set sin, relu and cos as acc nodes and split with settings.
    # Create splitter setting and set min_acc_module_size to 2.
    settings = splitter_base._SplitterSettingBase()
    settings.min_acc_module_size = 2
    splitter = TRTSplitter(
        mod,
        (torch.randn(2, 3),),
        op_support_with_support_dict(
            {
                "acc_ops.sin": None,
                "acc_ops.cos": None,
                "acc_ops.relu": None,
            }
        ),
        settings,
    )

    def test_splitter(splitter):
        st_split = splitter()
        try:
            verify_split_model(st_split)
        except RuntimeError as err:
            self.assertEqual(str(err), ERROR_MSG_NO_ACC_MODULE)
        modules = list(st_split.named_modules())
        # Main module and two submodules
        assert len(modules) == 3
        assert modules[1][0] == "_run_on_acc_0"
        assert modules[2][0] == "_run_on_gpu_1"

    test_splitter(splitter)
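# The test above uses the op_support_with_support_dict helper, which is not
# defined in this section. A minimal sketch of a compatible implementation,
# assuming op_support.OperatorSupport accepts a support dict at construction
# time (hedged: the real helper may differ):
def op_support_with_support_dict(support_dict: dict) -> op_support.OperatorSupport:
    # Wrap a plain {"acc_ops.<name>": None} mapping so the splitter can
    # query per-node support via OperatorSupport.is_node_supported.
    return op_support.OperatorSupport(support_dict)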
def __init__(
    self,
    module: torch.fx.GraphModule,
    sample_input: Tuple[torch.Tensor],
    operator_support: OperatorSupport = None,
    settings: splitter_base._SplitterSettingBase = None,
):
    if not operator_support:
        operator_support = TRTOperatorSupport()
    if not settings:
        settings = splitter_base._SplitterSettingBase()
    super().__init__(module, sample_input, operator_support, settings)
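# Illustrative usage of the constructor above: with operator_support and
# settings omitted, TRTOperatorSupport() and default splitter settings are
# used. _IdentityRelu is a placeholder module made up for this sketch.
class _IdentityRelu(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

traced = acc_tracer.trace(_IdentityRelu(), torch.randn(2, 3))
splitter = TRTSplitter(traced, (torch.randn(2, 3),))  # defaults applied
split_mod = splitter()  # GraphModule with _run_on_acc_* / _run_on_gpu_* children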
def __init__(
    self,
    module: torch.fx.GraphModule,
    sample_input: Tuple[torch.Tensor],
    operator_support: op_support.OperatorSupport = None,
    settings: splitter_base._SplitterSettingBase = None,
):
    if not operator_support:
        operator_support = op_support.OperatorSupport()
    if not settings:
        settings = splitter_base._SplitterSettingBase()
        settings.allow_non_tensor = True
        settings.skip_fusion = True
    super().__init__(module, sample_input, operator_support, settings)
def _trt_split(self, graph: fx.GraphModule, input: Input) -> fx.GraphModule:
    splitter_settings = _SplitterSettingBase()
    splitter_settings.min_acc_module_size = self.min_acc_module_size
    splitter = TRTSplitter(
        graph,
        input,  # type: ignore[arg-type]
        self.operator_supported,
        settings=splitter_settings,
    )
    logger.info("node_support_preview: %s", splitter.node_support_preview())
    return splitter()
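# A hedged sketch of consuming the split result returned by _trt_split.
# The _run_on_acc_* / _run_on_gpu_* naming follows the splitter conventions
# used throughout this section; the handling branches are placeholders, not
# the lowering pass's real dispatch logic.
def _dispatch_split_result(split_mod: fx.GraphModule) -> None:
    for name, submod in split_mod.named_children():
        if name.startswith("_run_on_acc_"):
            # TensorRT-eligible subgraph: would be handed to the TRT interpreter.
            logger.info("lowering %s to TensorRT", name)
        else:
            # Unsupported subgraph: stays on the eager fallback path.
            logger.info("keeping %s on the fallback path", name)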
def __init__(
    self,
    module: torch.fx.GraphModule,
    sample_input: Tuple[torch.Tensor],
    operator_support: ops.OperatorSupportBase = None,
    settings: splitter_base._SplitterSettingBase = None,
):
    if not operator_support:
        operator_support = create_trt_operator_support()
    if not settings:
        settings = splitter_base._SplitterSettingBase()
    super().__init__(
        module,
        sample_input,
        operator_support,
        settings,
        non_acc_submodule_name="_run_on_gpu_",
    )
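# Note on non_acc_submodule_name above: with "_run_on_gpu_", fallback
# submodules are named _run_on_gpu_0, _run_on_gpu_1, ... while supported
# subgraphs become _run_on_acc_0, _run_on_acc_1, ... For example
# (illustrative output):
#
#     split_mod = splitter()
#     print([name for name, _ in split_mod.named_children()])
#     # e.g. ['_run_on_acc_0', '_run_on_gpu_1']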
def test_split_non_tensor_edges_4(self):
    test_data = torch.randn(2, 3)

    module_nn = acc_tracer.trace(
        self.TestModule(),
        (test_data,),
    )

    # Making 'a', 'c', 'd' and 'e' run on ACC with a limit on ACC
    # subgraph size
    settings = splitter_base._SplitterSettingBase()
    settings.min_acc_module_size = 2
    splitter = TRTSplitter(
        module_nn,
        (test_data,),
        op_support_with_support_dict(
            {
                "acc_ops.relu": None,
                "acc_ops.sigmoid": None,
                "acc_ops.cos": None,
                "acc_ops.add": None,
            }
        ),
        settings,
    )

    def test_splitter(splitter):
        module_fx_split = splitter()
        verify_split_model(module_fx_split)

        self.assertEqual(
            {acc_ops.relu, acc_ops.cos},
            find_call_targets(module_fx_split._run_on_acc_0),
        )
        self.assertEqual(
            {acc_ops.size, acc_ops.getitem, acc_ops.add, acc_ops.sigmoid},
            find_call_targets(module_fx_split._run_on_cpu_1),
        )

        # Make sure we can compile to TorchScript
        module_jit = torch.jit.trace_module(module_fx_split, {"forward": test_data})
        self.assertTrue(torch.allclose(module_nn(test_data), module_jit(test_data)))

    test_splitter(splitter)
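# find_call_targets above is a test helper not shown in this section. A
# minimal sketch of a compatible implementation, assuming it gathers the
# targets of all call_* nodes in a GraphModule's graph:
def find_call_targets(module: torch.fx.GraphModule):
    targets = set()
    for node in module.graph.nodes:
        if node.op in {"call_function", "call_method", "call_module"}:
            targets.add(node.target)
    return targets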