Example #1
    def test_unsupport_node_implicit_batch_dim(self):
        class TestModule(nn.Module):
            def forward(self, x):
                y = torch.add(input=x, other=x)
                return nn.functional.gelu(y)

        mod = TestModule()
        traced_mod = acc_tracer.trace(mod, torch.randn(5, 2))
        op_support = create_trt_operator_support(use_implicit_batch_dim=True)

        for node in traced_mod.graph.nodes:
            if node.target == acc_ops.add:
                self.assertTrue(op_support.is_node_supported(mod, node))
            elif node.target == acc_ops.gelu:
                self.assertFalse(op_support.is_node_supported(mod, node))
Example #3
    def test_support_node_with_int_attr(self):
        class TestModule(nn.Module):
            def forward(self, x):
                zeros = torch.randint(3, 5, (1, ))
                zeros = zeros.to(torch.int64)
                scale = torch.randn(1)
                return torch.quantize_per_tensor(x, scale, zeros, torch.quint8)

        mod = TestModule()
        traced_mod = acc_tracer.trace(mod, torch.randn(5, 2))
        op_support = create_trt_operator_support(use_implicit_batch_dim=True)

        for node in traced_mod.graph.nodes:
            if node.target == acc_ops.quantize_per_tensor:
                self.assertTrue(op_support.is_node_supported(mod, node))
    def test_supported_node_target(self):
        class TestModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(1, 1)

            def forward(self, x):
                x = self.linear(x)
                x = x + 1
                return torch.add(input=x, other=x)

        mod = TestModule()
        traced_mod = acc_tracer.trace(mod, torch.randn(1, 2, 1, 1))
        op_support = create_trt_operator_support()
        for node in traced_mod.graph.nodes:
            self.assertTrue(op_support.is_node_supported(mod, node))
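The tests above query `create_trt_operator_support` from inside a unittest harness; the same per-node check can be run standalone. Below is a minimal sketch that mirrors the call pattern of those tests; the import paths are assumptions and differ between fx2trt and torch_tensorrt releases.

import torch
from torch import nn

# NOTE: import paths are assumptions; depending on the release these modules
# may live under torch_tensorrt.fx instead of torch.fx.experimental.
from torch.fx.experimental.fx_acc import acc_tracer
from torch.fx.experimental.fx2trt.tools.trt_splitter import create_trt_operator_support


class SmallModule(nn.Module):
    def forward(self, x):
        y = torch.add(input=x, other=x)
        return nn.functional.gelu(y)


if __name__ == "__main__":
    mod = SmallModule()
    traced = acc_tracer.trace(mod, torch.randn(5, 2))
    op_support = create_trt_operator_support(use_implicit_batch_dim=True)

    # Report support for every call_function node, mirroring the assertions above.
    for node in traced.graph.nodes:
        if node.op == "call_function":
            print(node.format_node(), "->", op_support.is_node_supported(mod, node))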
Example #5
class Splitter(SplitFunc):
    """A composable fx2trtr splitter.

    See `SplitFunc`.

    Attributes:
        min_acc_module_size: minimum number of nodes an accelerated (acc)
            submodule must contain to be split out
    """

    _INPUT_ATTR: t.ClassVar[str] = "_split_graph_recorded_input"
    min_acc_module_size: int
    operator_supported: OperatorSupportBase = dc.field(
        default_factory=lambda: create_trt_operator_support())

    @classmethod
    def create(
        cls,
        use_implicit_batch_dim: bool,
        min_acc_module_size: int = 20,
    ):
        return Splitter(
            min_acc_module_size=min_acc_module_size,
            operator_supported=create_trt_operator_support(
                use_implicit_batch_dim),
        )

    def __call__(self, module,
                 input) -> Tuple[fx.GraphModule, t.Sequence[SplitInfo]]:
        trt_split_result = self._trt_split(module, input)

        logger.debug(f"""TRT split result graph >>> {
                trt_split_result.graph
            }""")

        Splitter._propagate_split_inputs(
            trt_split_result, input,
            dict(trt_split_result.named_children()).keys())

        return (
            trt_split_result,
            [
                Splitter._create_split_info(name,
                                            subgraph,
                                            parent=trt_split_result)
                for name, subgraph in trt_split_result.named_children()
            ],
        )

    @classmethod
    def _propagate_split_inputs(
        cls,
        graph: fx.GraphModule,
        input: Input,
        target_modules: t.Collection[str],
    ) -> None:
        """
        Input propagation on subnets

        TODO: refactor so we don't set inputs onto the subgraphs
        """
        handles = []

        def pre_forward(mod, input):
            setattr(mod, cls._INPUT_ATTR, input)

        def _install_hook(g):
            nonlocal handles
            if not g:
                return
            for _n, _g in g.named_children():
                if _n in target_modules:
                    handles.append(_g.register_forward_pre_hook(pre_forward))
                    _install_hook(_g)

        try:
            _install_hook(graph)
            graph(*input)
        finally:
            for h in handles:
                h.remove()

    @classmethod
    def _create_split_info(cls, name, graph, parent) -> SplitInfo:
        device, order = cls._parse_splitter_subgraph_name(name)
        input = getattr(graph, cls._INPUT_ATTR)
        delattr(graph, cls._INPUT_ATTR)
        return SplitInfo(
            module=graph,
            input=input,
            name=name,
            device=device,
            order=order,
        )

    @classmethod
    def _parse_splitter_subgraph_name(cls, name: str) -> t.Tuple[str, int]:
        match = re.match("_run_on_([a-z]+)_([0-9]+)", name)
        assert match, f"{name} doesn't comform with splitter subgraph naming convention"
        return (match[1], int(match[2]))

    def _trt_split(self, graph: fx.GraphModule,
                   input: Input) -> fx.GraphModule:
        splitter_settings = _SplitterSettingBase()
        splitter_settings.min_acc_module_size = self.min_acc_module_size

        splitter = TRTSplitter(
            graph,
            input,  # type: ignore[arg-type]
            self.operator_supported,
            settings=splitter_settings,
        )
        logger.info(f"""{splitter.node_support_preview.__name__}: {
            splitter.node_support_preview()
            }""")
        return splitter()
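For context, a hypothetical driver showing how `Splitter` might be wired up end to end, based only on the `create` and `__call__` signatures above; `traced_module` and `sample_inputs` are placeholders for an acc-traced `fx.GraphModule` and its matching inputs.

# Hypothetical usage sketch; `traced_module` and `sample_inputs` are placeholders.
splitter = Splitter.create(use_implicit_batch_dim=True, min_acc_module_size=10)
split_graph, split_infos = splitter(traced_module, sample_inputs)

for info in split_infos:
    # SplitInfo fields as populated by _create_split_info above.
    print(info.name, info.device, info.order, type(info.module))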