def build_int8_trt(rn18):
    rn18 = copy.deepcopy(rn18)
    data = torch.randn(1, 3, 224, 224)
    # data = torch.randn(1, 32)
    # data = torch.randn(1, 64, 10, 10)
    # TensorRT only supports symmetric quantization
    qconfig = torch.quantization.QConfig(
        activation=torch.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, dtype=torch.qint8),
        weight=torch.quantization.default_weight_observer)
    prepared = prepare_fx(rn18, {"": qconfig})
    for _ in range(10):
        prepared(data)
    quantized_rn18 = convert_fx(prepared, is_reference=True)
    ref_res = quantized_rn18(data)
    print("quantized model:", quantized_rn18)

    quantized_rn18 = acc_tracer.trace(quantized_rn18,
                                      [data])  # type: ignore[attr-defined]
    interp = TRTInterpreter(quantized_rn18, [
        InputTensorSpec(torch.Size([-1, *data.shape[1:]]),
                        torch.float,
                        shape_ranges=[((1, 3, 224, 224), (5, 3, 224, 224),
                                       (10, 3, 224, 224))],
                        has_batch_dim=True)
    ],
                            explicit_batch_dimension=True,
                            explicit_precision=True)
    engine, input_names, output_names = interp.run(fp16_mode=False,
                                                   int8_mode=True)
    trt_mod = TRTModule(engine, input_names, output_names)
    trt_res = trt_mod(data.cuda())
    print("result diff max", torch.max(ref_res - trt_res.cpu()))
    return trt_mod
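A hedged usage sketch for build_int8_trt; it assumes a torchvision ResNet-18 (inferred only from the `rn18` parameter name) and a CUDA-capable machine:

import torchvision.models as models

rn18 = models.resnet18(pretrained=True).eval()
int8_trt = build_int8_trt(rn18)  # calibrates on random data inside the function, then builds an INT8 engine
out = int8_trt(torch.randn(1, 3, 224, 224).cuda())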
Example #2
    def test_longer_chain(self):
        """
           sin     relu     cos     sigmoid     tanh
        a ====> b =====> c ====> d ========> e =====> f
        """

        class TestModule(torch.nn.Module):
            def forward(self, a):
                b = torch.sin(a)
                c = torch.relu(b)
                d = torch.cos(c)
                e = torch.sigmoid(d)
                f = torch.tanh(e)
                return f

        mod = acc_tracer.trace(TestModule(), torch.randn(2, 3))

        # Making relu and sigmoid execute on ACC
        splitter = TRTSplitter(
            mod,
            (torch.randn(2, 3),),
            op_support_with_support_dict(
                {
                    "acc_ops.relu": None,
                    "acc_ops.sigmoid": None,
                }
            ),
        )

        def test_splitter(splitter):
            st_split = splitter()
            try:
                verify_split_model(st_split)
            except RuntimeError as err:
                self.assertEqual(
                    str(err), ERROR_MSG_MULTI_ACC_MODULES
                )
            [arg] = find_inputs(st_split)

            # First subgraph calculates b = sin(a) on CPU
            [sin] = find_fun_calls(st_split._run_on_gpu_0, acc_ops.sin)
            self.assertEqual(arg.name, sin.kwargs["input"].name)

            # Second subgraph calculates c = relu(b) on ACC
            [relu] = find_fun_calls(st_split._run_on_acc_1, acc_ops.relu)
            self.assertEqual(sin.name, relu.kwargs["input"].name)

            # Third subgraph calculates d = cos(c) on CPU
            [cos] = find_fun_calls(st_split._run_on_gpu_2, acc_ops.cos)
            self.assertEqual(relu.name, cos.kwargs["input"].name)

            # Fourth subgraph calculates e = sigmoid(d) on ACC
            [sigmoid] = find_fun_calls(st_split._run_on_acc_3, acc_ops.sigmoid)
            self.assertEqual(cos.name, sigmoid.kwargs["input"].name)

            # Fifth subgraph calculates f = tanh(e) on CPU
            [tanh] = find_fun_calls(st_split._run_on_gpu_4, acc_ops.tanh)
            self.assertEqual(sigmoid.name, tanh.kwargs["input"].name)

        test_splitter(splitter)
Example #3
    def test_nothing_to_split(self):
        class SimpleModule(torch.nn.Module):
            def forward(self, a):
                return a

        mod = acc_tracer.trace(SimpleModule(), torch.randn(2, 3))

        # Mark any operation as runnable on ACC
        class CustomOpSupport(op_support.OperatorSupportBase):
            def is_node_supported(self, submodules, node):
                return True

        splitter = TRTSplitter(
            mod, (torch.randn(2, 3),), CustomOpSupport()
        )

        def test_splitter(splitter):
            st_split = splitter()
            try:
                verify_split_model(st_split)
            except RuntimeError as err:
                self.assertEqual(
                    str(err), ERROR_MSG_NO_ACC_MODULE
                )
            self.assertEqual(splitter.module.__dict__.keys(), st_split.__dict__.keys())

        test_splitter(splitter)
Example #4
    def test_start_with_acc_module_(self):
        """
           sin     relu     cos     sigmoid     tanh
        a ====> b =====> c ====> d ========> e =====> f

        We set sin, relu and cos as acc nodes and set min_acc_module_size to 2.
        Since the leading acc group (sin, relu, cos) meets that size threshold,
        we expect one acc submodule followed by one GPU submodule.
        """

        class TestModule(torch.nn.Module):
            def forward(self, a):
                b = torch.sin(a)
                c = torch.relu(b)
                d = torch.cos(c)
                e = torch.sigmoid(d)
                f = torch.tanh(e)
                return f

        mod = acc_tracer.trace(TestModule(), torch.randn(2, 3))

        # Set sin, cos and relu as acc nodes and split with the settings below
        class CustomOpSupport(op_support.OperatorSupport):
            _support_dict = {
                "acc_ops.sin": None,
                "acc_ops.cos": None,
                "acc_ops.relu": None,
            }

        # Create splitter setting and set min_acc_module_size to 2
        settings = splitter_base._SplitterSettingBase()
        settings.min_acc_module_size = 2
        splitter = TRTSplitter(
            mod,
            (torch.randn(2, 3),),
            op_support_with_support_dict(
                {
                    "acc_ops.sin": None,
                    "acc_ops.cos": None,
                    "acc_ops.relu": None,
                }
            ),
            settings,
        )

        def test_splitter(splitter):
            st_split = splitter()
            try:
                verify_split_model(st_split)
            except RuntimeError as err:
                self.assertEqual(
                    str(err), ERROR_MSG_NO_ACC_MODULE
                )
            modules = list(st_split.named_modules())
            # Main module and a submodule
            assert len(modules) == 3

            assert modules[1][0] == "_run_on_acc_0"
            assert modules[2][0] == "_run_on_gpu_1"

        test_splitter(splitter)
Example #5
    def run_test(
        self,
        mod,
        inputs,
        expected_ops,
        apply_passes=None,
        test_explicit_batch_dim=True,
        test_implicit_batch_dim=True,
        rtol=1e-03,
        atol=1e-03,
    ):
        mod.eval()
        mod = acc_tracer.trace(mod, inputs)

        if apply_passes is not None:
            for p in apply_passes:
                mod = p(mod)

        if test_implicit_batch_dim:
            interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
            super().run_test(mod, inputs, expected_ops, interp, rtol, atol)

        if test_explicit_batch_dim:
            interp = TRTInterpreter(mod,
                                    InputTensorSpec.from_tensors(inputs),
                                    explicit_batch_dimension=True)
            super().run_test(mod, inputs, expected_ops, interp, rtol, atol)
Example #6
def build_fp16_trt(rn18):
    rn18 = copy.deepcopy(rn18)
    rn18 = acc_tracer.trace(rn18, [torch.randn(1, 3, 224, 224)])
    interp = TRTInterpreter(
        rn18, [InputTensorSpec(torch.Size([3, 224, 224]), torch.float, has_batch_dim=False)])
    interpreter_result = interp.run(fp16_mode=True)
    return TRTModule(interpreter_result.engine, interpreter_result.input_names, interpreter_result.output_names)
Example #7
def build_fp16_trt(rn18):
    rn18 = copy.deepcopy(rn18)
    rn18 = acc_tracer.trace(
        rn18, [torch.randn(1, 3, 224, 224)])  # type: ignore[attr-defined]
    interp = TRTInterpreter(rn18, [
        InputTensorSpec(
            torch.Size([3, 224, 224]), torch.float, has_batch_dim=False)
    ])
    engine, input_names, output_names = interp.run(fp16_mode=True)
    return TRTModule(engine, input_names, output_names)
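A similarly hedged usage sketch for the FP16 builders above; again it assumes a torchvision ResNet-18. The engine is built from a (3, 224, 224) implicit-batch spec, so a batch-1 CUDA input should work as long as it fits the default max batch size:

import torchvision.models as models

rn18 = models.resnet18(pretrained=True).eval()
fp16_trt = build_fp16_trt(rn18)
out = fp16_trt(torch.randn(1, 3, 224, 224).cuda())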
Example #8
def lower_to_trt(model, sample_input, shape_ranges):
    model = acc_tracer.trace(model, [sample_input])  # type: ignore[attr-defined]
    interp = TRTInterpreter(
        model,
        [InputTensorSpec(
            torch.Size([-1, *sample_input.shape[1:]]), torch.float,
            shape_ranges=shape_ranges, has_batch_dim=True)],
        explicit_batch_dimension=True, explicit_precision=True)
    engine, input_names, output_names = interp.run(fp16_mode=False, int8_mode=True)
    trt_mod = TRTModule(engine, input_names, output_names)
    return trt_mod
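Each entry of `shape_ranges` follows TensorRT's (min, opt, max) convention for a dynamic shape, matching the spec used in Example #1. A hedged call sketch, assuming `model` is a reference-quantized module (e.g. the output of convert_fx) with NCHW image input:

shape_ranges = [((1, 3, 224, 224), (8, 3, 224, 224), (32, 3, 224, 224))]  # (min, opt, max)
trt_mod = lower_to_trt(model, torch.randn(1, 3, 224, 224), shape_ranges)
out = trt_mod(torch.randn(4, 3, 224, 224).cuda())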
Example #9
    def test_mod_with_getattr(self):
        """
        CPU subgraph should have get_attr for self.a while ACC subgraph
        should have get_attr for self.b.
        """

        class SimpleModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.randn(1, 1, 1, 1)
                self.b = torch.randn(1, 1, 1, 1)
                self.conv = torch.nn.Conv2d(1, 1, 1)
                self.linear = torch.nn.Linear(1, 1)

            def forward(self, x):
                x = x + self.a
                x = self.conv(x)
                return self.linear(x - self.b)

        mod = acc_tracer.trace(SimpleModule(), torch.randn(1, 1, 1, 1))
        mod.eval()

        splitter = TRTSplitter(
            mod,
            (torch.randn(1, 1, 1, 1),),
            op_support_with_support_dict(
                {
                    "acc_ops.linear": None,
                    "acc_ops.sub": None,
                }
            ),
        )

        def test_splitter(splitter):
            st_split = splitter()
            verify_split_model(st_split)
            # Should be "a", "conv.weight", "conv.bias".
            get_attr_nodes = [
                node.target
                for node in st_split._run_on_gpu_0.graph.nodes
                if node.op == "get_attr"
            ]
            assert len(get_attr_nodes) == 3 and "a" in get_attr_nodes

            # Should be "b", "conv.weight", "conv.bias".
            get_attr_nodes = [
                node.target
                for node in st_split._run_on_acc_1.graph.nodes
                if node.op == "get_attr"
            ]
            assert len(get_attr_nodes) == 3 and "b" in get_attr_nodes

        test_splitter(splitter)
Example #10
 def run_test_with_dynamic_shape(
     self,
     mod,
     input_specs,
     expected_ops,
     rtol=1e-03,
     atol=1e-03,
 ):
     mod.eval()
     inputs = create_inputs_from_specs(input_specs)
     mod = acc_tracer.trace(mod, inputs)
     interp = TRTInterpreter(mod, input_specs, explicit_batch_dimension=True)
     super().run_test(mod, inputs, expected_ops, interp, rtol, atol)
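A hedged sketch of the `input_specs` argument this helper expects for a dynamic batch dimension, mirroring the InputTensorSpec construction used in the other examples here (constructor arguments may differ slightly across fx2trt versions):

input_specs = [
    InputTensorSpec(
        torch.Size([-1, 3, 224, 224]),
        torch.float,
        shape_ranges=[((1, 3, 224, 224), (4, 3, 224, 224), (8, 3, 224, 224))],
        has_batch_dim=True,
    )
]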
Example #11
    def create(
        cls,
        lower_setting: LowerSetting,
    ) -> "Lowerer":
        """Instantiate a `Lowerer` instance."""

        return Lowerer(
            split=Splitter.create(not lower_setting.explicit_batch_dimension),
            acc_trace=lambda mod, input: acc_tracer.trace(mod, input),  # type: ignore[arg-type]
            remove_duplicate_output_args=remove_duplicate_output_args,
            trt_interpreter=LowerTrtInterpreter.create(lower_setting),
            fp16=lower_setting.fp16_mode,
        )
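A hedged construction sketch for this factory. LowerSetting's exact fields vary across fx2trt versions; only the attributes this `create` reads (explicit_batch_dimension, fp16_mode) are set here, and a default constructor is assumed:

setting = LowerSetting()
setting.explicit_batch_dimension = True
setting.fp16_mode = True
lowerer = Lowerer.create(setting)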
Example #12
    def test_split_non_tensor_edges_3(self):
        test_data = torch.randn(2, 3)

        module_nn = acc_tracer.trace(
            self.TestModule(),
            (test_data, ),
        )

        # Making 'a', 'c', 'd' and 'e' run on ACC
        splitter = TRTSplitter(
            module_nn,
            (test_data, ),
            op_support_with_support_dict({
                "acc_ops.relu": None,
                "acc_ops.sigmoid": None,
                "acc_ops.cos": None,
                "acc_ops.add": None,
            }),
        )

        def test_splitter(splitter):
            module_fx_split = splitter()
            try:
                verify_split_model(module_fx_split)
            except RuntimeError as err:
                self.assertEqual(str(err), ERROR_MSG_MULTI_ACC_MODULES)

            self.assertEqual(
                {acc_ops.relu, acc_ops.cos},
                find_call_targets(module_fx_split._run_on_acc_0),
            )

            self.assertEqual(
                {acc_ops.size, acc_ops.getitem, acc_ops.add},
                find_call_targets(module_fx_split._run_on_cpu_1),
            )

            self.assertEqual(
                {acc_ops.sigmoid},
                find_call_targets(module_fx_split._run_on_acc_2),
            )

            # Make sure we can compile to TorchScript
            module_jit = torch.jit.trace_module(module_fx_split,
                                                {"forward": test_data})
            self.assertTrue(
                torch.allclose(module_nn(test_data), module_jit(test_data)))

        test_splitter(splitter)
Example #13
    def test_get_attr_into_output(self):
        """
        Here we verify the case when a get_attr node is consumed directly by the
        output. We don't expect any split to happen in this test; we just want to
        make sure that the splitter code doesn't break.
        """

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.randn(2, 3)

            def forward(self, x):
                return (x, self.a)

        # No need to put anything on ACC.
        class TestOperatorSupport:
            def is_node_supported(self, submodules, node):
                return False

        module_original = acc_tracer.trace(TestModule(), torch.randn(4, 5))

        splitter = TRTSplitter(
            module=module_original,
            sample_input=torch.randn(4, 5),
            operator_support=TestOperatorSupport(),
        )

        def test_splitter(splitter):
            module_split = splitter()
            try:
                verify_split_model(module_split)
            except RuntimeError as err:
                self.assertEqual(
                    str(err), ERROR_MSG_NO_ACC_MODULE
                )

            output = find_output(module_split)
            # Second argument of the output should be get_attr.
            self.assertEqual("get_attr", output.args[0][1].op)

            # Check if modules are equivalent.
            tensor = torch.randn(10, 20)
            result_original = module_original(tensor)
            result_split = module_split(tensor)
            self.assertTrue(torch.equal(result_original[0], result_split[0]))
            self.assertTrue(torch.equal(result_original[1], result_split[1]))

        test_splitter(splitter)
Example #14
    def test_support_node_with_int_attr(self):
        class TestModule(nn.Module):
            def forward(self, x):
                zeros = torch.randint(3, 5, (1, ))
                zeros = zeros.to(torch.int64)
                scale = torch.randn(1)
                return torch.quantize_per_tensor(x, scale, zeros, torch.quint8)

        mod = TestModule()
        traced_mod = acc_tracer.trace(mod, torch.randn(5, 2))
        op_support = create_trt_operator_support(use_implicit_batch_dim=True)

        for node in traced_mod.graph.nodes:
            if node.target == acc_ops.quantize_per_tensor:
                self.assertTrue(op_support.is_node_supported(mod, node))
Example #15
    def test_unsupport_node_implicit_batch_dim(self):
        class TestModule(nn.Module):
            def forward(self, x):
                y = torch.add(input=x, other=x)
                return nn.functional.gelu(y)

        mod = TestModule()
        traced_mod = acc_tracer.trace(mod, torch.randn(5, 2))
        op_support = create_trt_operator_support(use_implicit_batch_dim=True)

        for node in traced_mod.graph.nodes:
            if node.target == acc_ops.add:
                self.assertTrue(op_support.is_node_supported(mod, node))
            elif node.target == acc_ops.gelu:
                self.assertFalse(op_support.is_node_supported(mod, node))
Example #16
    def test_supported_node_target(self):
        class TestModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(1, 1)

            def forward(self, x):
                x = self.linear(x)
                x = x + 1
                return torch.add(input=x, other=x)

        mod = TestModule()
        traced_mod = acc_tracer.trace(mod, torch.randn(1, 2, 1, 1))
        op_support = create_trt_operator_support()
        for node in traced_mod.graph.nodes:
            self.assertTrue(op_support.is_node_supported(mod, node))
Example #17
    def test_split_non_tensor_edges_4(self):
        test_data = torch.randn(2, 3)

        module_nn = acc_tracer.trace(
            self.TestModule(),
            (test_data, ),
        )

        # Making 'a', 'c', 'd' and 'e' run on ACC with a limit on the ACC
        # subgraph size
        settings = splitter_base._SplitterSettingBase()
        settings.min_acc_module_size = 2
        splitter = TRTSplitter(
            module_nn,
            (test_data, ),
            op_support_with_support_dict({
                "acc_ops.relu": None,
                "acc_ops.sigmoid": None,
                "acc_ops.cos": None,
                "acc_ops.add": None,
            }),
            settings,
        )

        def test_splitter(splitter):
            module_fx_split = splitter()
            verify_split_model(module_fx_split)

            self.assertEqual(
                {acc_ops.relu, acc_ops.cos},
                find_call_targets(module_fx_split._run_on_acc_0),
            )

            self.assertEqual(
                {acc_ops.size, acc_ops.getitem, acc_ops.add, acc_ops.sigmoid},
                find_call_targets(module_fx_split._run_on_cpu_1),
            )

            # Make sure we can compile to TorchScript
            module_jit = torch.jit.trace_module(module_fx_split,
                                                {"forward": test_data})
            self.assertTrue(
                torch.allclose(module_nn(test_data), module_jit(test_data)))

        test_splitter(splitter)
Example #18
    def test_demo(self):
        """
          ==> b ==>
        //         \\
       a             d
        \\         //
          ==> c ==>
        """

        class SimpleModule(torch.nn.Module):
            def forward(self, a):
                b = torch.sin(a)
                c = torch.cos(a)
                d = b + c
                return d

        mod = acc_tracer.trace(SimpleModule(), torch.randn(2, 3))

        # Making b and c run on ACC
        splitter = TRTSplitter(
            mod,
            (torch.randn(2, 3),),
            op_support_with_support_dict(
                {
                    "acc_ops.sin": None,
                    "acc_ops.cos": None,
                }
            ),
        )

        st_split = splitter()

        [arg] = find_inputs(st_split)

        # First subgraph calculates b = sin(a) and c = cos(a) on ACC
        [sin] = find_fun_calls(st_split._run_on_acc_0, acc_ops.sin)
        self.assertEqual(arg.name, sin.kwargs["input"].name)

        [cos] = find_fun_calls(st_split._run_on_acc_0, acc_ops.cos)
        self.assertEqual(arg.name, cos.kwargs["input"].name)

        # Second subgraph calculates d = b + c on CPU
        [add] = find_fun_calls(st_split._run_on_gpu_1, acc_ops.add)
        self.assertEqual(sin.name, add.kwargs["input"].name)
        self.assertEqual(cos.name, add.kwargs["other"].name)
Example #19
    def test_get_attr_into_starter_node(self):
        """
        Here we verify the case when starter nodes depend on a get_attr node only.
        We don't expect any split to happen in this test; we just want to make sure
        that the splitter code doesn't break.
        """

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.randn(2, 3)

            def forward(self):
                m = self.a + self.a
                o = m + m
                return o

        # No need to put anything on ACC.
        class TestOperatorSupport:
            def is_node_supported(self, submodules, node):
                return False

        module_original = acc_tracer.trace(TestModule(), torch.randn(2, 3))

        splitter = TRTSplitter(
            module=module_original,
            sample_input=torch.randn(2, 3),
            operator_support=TestOperatorSupport(),
        )

        def test_splitter(splitter):
            module_split = splitter()
            try:
                verify_split_model(module_split)
            except RuntimeError as err:
                self.assertEqual(
                    str(err), ERROR_MSG_NO_ACC_MODULE
                )

            # Check if modules are equivalent.
            result_original = module_original()
            result_split = module_split()
            self.assertTrue(torch.equal(result_original, result_split))

        test_splitter(splitter)
Example #20
    def test_multi_output(self):
        class MultiOutputModule(torch.nn.Module):
            def forward(self, x):
                res, ind = torch.topk(x, 3)
                return torch.sigmoid(res), ind

        mod = acc_tracer.trace(MultiOutputModule(), torch.randn(2, 3))

        # Mark any operation as runnable on ACC
        class CustomOpSupport(op_support.OperatorSupportBase):
            def is_node_supported(self, submodules, node):
                return True

        splitter = TRTSplitter(
            mod, (torch.randn(2, 3),), CustomOpSupport()
        )

        def test_splitter(splitter):
            st_split = splitter()
            verify_split_model(st_split)
            [arg] = find_inputs(st_split)

            # There is only one subgraph that executes topk and sigmoid on ACC
            [topk] = find_fun_calls(st_split._run_on_acc_0, acc_ops.topk)
            self.assertEqual(arg.name, topk.kwargs["input"].name)
            self.assertEqual(3, topk.kwargs["k"])

            [topk_res1, topk_res2] = find_fun_calls(
                st_split._run_on_acc_0, acc_ops.getitem
            )

            [sigmoid] = find_fun_calls(st_split._run_on_acc_0, acc_ops.sigmoid)
            self.assertIn(
                sigmoid.kwargs["input"].name, {topk_res1.name, topk_res2.name}
            )

            # Main graph returns a tuple
            output = find_output(st_split._run_on_acc_0)
            self.assertLess(
                {output.args[0][0].name, output.args[0][1].name},
                {topk_res1.name, topk_res2.name, sigmoid.name},
            )

        test_splitter(splitter)
Example #21
    def test_split_complex_graph_2(self):
        module_nn = self.TestModule()
        module = acc_tracer.trace(module_nn, (torch.randn(2, 3),))

        # Making 'c', 'd' and 'e' run on ACC
        splitter = TRTSplitter(
            module,
            (torch.randn(2, 3),),
            op_support_with_support_dict(
                {
                    "acc_ops.cos": None,
                    "acc_ops.relu": None,
                    "acc_ops.add": None,
                }
            ),
        )

        def test_splitter(splitter):
            module_fx_split = splitter()
            verify_split_model(module_fx_split)

            [arg] = find_inputs(module)

            # First subgraph calculates b = sin(a) on CPU
            [sin] = find_fun_calls(module_fx_split._run_on_gpu_0, acc_ops.sin)
            self.assertEqual(arg.name, sin.kwargs["input"].name)

            # Second subgraph calculates c = relu(a), d = cos(a) and e = b + c on ACC
            [relu] = find_fun_calls(module_fx_split._run_on_acc_1, acc_ops.relu)
            self.assertEqual(arg.name, relu.kwargs["input"].name)

            [cos] = find_fun_calls(module_fx_split._run_on_acc_1, acc_ops.cos)
            self.assertEqual(arg.name, cos.kwargs["input"].name)

            [add] = find_fun_calls(module_fx_split._run_on_acc_1, acc_ops.add)
            self.assertEqual(sin.name, add.kwargs["input"].name)
            self.assertEqual(relu.name, add.kwargs["other"].name)

            # Third subgraph calculates f = e - d on CPU
            [sub] = find_fun_calls(module_fx_split._run_on_gpu_2, acc_ops.sub)
            self.assertEqual(add.name, sub.kwargs["input"].name)
            self.assertEqual(cos.name, sub.kwargs["other"].name)

        test_splitter(splitter)
Example #22
    def test_check_skip_folding_quant_dequant_pattern(self):
        r"""
        Set up skip_folding_quant_dequant function to skip quant/dequant pattern.
        This example shows how to use skip_folding_node_fn.
        """
        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.randn(4, 4))
                self.bias = torch.nn.Parameter(torch.randn(4))
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                quant_weight = torch.quantize_per_tensor(
                    self.weight, 0.5, 3, torch.quint8)
                dequant_weight = torch.dequantize(quant_weight)
                output = torch.nn.functional.linear(x, dequant_weight,
                                                    self.bias)
                return self.relu(output)

        mod = ConstFoldTestModule()
        in_x = torch.randn(2, 4)
        gm = acc_tracer.trace(mod, in_x)

        def skip_folding_quant_dequant(node: torch.fx.Node):
            if node.target != acc_ops.quantize_per_tensor:
                return False
            # If quantize_per_tensor -> dequantize, then skip folding.
            for user in node.users:
                if user.target == acc_ops.dequantize:
                    return True
            return False

        gm_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(
            gm, skip_folding_node_fn=skip_folding_quant_dequant)

        # Check that the constant subgraph module is None, since there was no folding to do.
        self.assertTrue(gm_folded.const_subgraph_module is None)

        # Now run both folded and non-folded to check results equal.
        fold_result = gm_folded(in_x)
        base_result = mod(in_x)
        self.assertTrue(torch.equal(fold_result, base_result))
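        # For contrast, a hedged sketch without the skip function: the quantize/
        # dequantize chain over self.weight has only constant inputs, so it would
        # be expected to fold into a constant subgraph (behavior may vary by version).
        gm_folded_all = const_fold.split_const_subgraphs(gm)
        print(gm_folded_all.const_subgraph_module)  # expected to be populated when nothing is skipped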
Example #23
def lower_to_trt(model, inputs, shape_ranges):
    """ Lower a quantized model to TensorRT
    """
    assert len(inputs) == 1, "lower_to_trt only works for one input currently"
    model = acc_tracer.trace(model, inputs)  # type: ignore[attr-defined]
    # TODO: test multiple inputs setting and enable multiple inputs
    input_specs = [
        InputTensorSpec(
            torch.Size([-1, *inputs[0].shape[1:]]), torch.float,
            shape_ranges=shape_ranges, has_batch_dim=True)
    ]

    interp = TRTInterpreter(
        model,
        input_specs,
        explicit_batch_dimension=True, explicit_precision=True)
    result = interp.run(fp16_mode=False, int8_mode=True)
    trt_mod = TRTModule(result.engine, result.input_names, result.output_names)
    return trt_mod
Example #24
    def create(
        cls,
        lower_setting: LowerSetting,
        trt_module_observer: Optional[Callable[
            [str, nn.Module, List[torch.Tensor]], None]] = None,
    ) -> "Lowerer":
        """Instantiate a `Lowerer` instance."""

        return cls(
            trace_func=lambda module, inputs: acc_tracer.trace(
                module,
                inputs,  # type: ignore[arg-type]
                ast_rewriter_allow_list=lower_setting.ast_rewriter_allow_list,
                leaf_module_list=lower_setting.leaf_module_list
            ),  # type: ignore[arg-type]
            split_func=default_split_function,
            trt_interpreter=LowerTrtInterpreter.create(lower_setting),
            lower_setting=lower_setting,
            trt_module_observer=trt_module_observer,
        )
Example #25
    def test_save_and_load_trt_module(self):
        class TestModule(torch.nn.Module):
            def forward(self, x):
                return x + x

        inputs = [torch.randn(1, 1)]
        mod = TestModule().eval()
        ref_output = mod(*inputs)

        mod = acc_tracer.trace(mod, inputs)
        interp = TRTInterpreter(
            mod, input_specs=InputTensorSpec.from_tensors(inputs))
        trt_mod = TRTModule(*interp.run(fp16_mode=False))
        torch.save(trt_mod, "trt.pt")
        reload_trt_mod = torch.load("trt.pt")

        torch.testing.assert_allclose(reload_trt_mod(inputs[0].cuda()).cpu(),
                                      ref_output,
                                      rtol=1e-04,
                                      atol=1e-04)
Example #26
    def run_test_with_assert_error(
        self,
        mod,
        inputs,
        expect_error,
        test_explicit_batch_dim=True,
        test_implicit_batch_dim=True,
    ):
        mod.eval()
        mod = acc_tracer.trace(mod, inputs)

        if test_implicit_batch_dim:
            interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
            super().run_test_with_error(mod, inputs, interp, expect_error)

        if test_explicit_batch_dim:
            interp = TRTInterpreter(mod,
                                    InputTensorSpec.from_tensors(inputs),
                                    explicit_batch_dimension=True)
            super().run_test_with_error(mod, inputs, interp, expect_error)
Example #27
    def create(
        cls,
        lower_setting: LowerSetting,
        trt_module_observer: Optional[Callable[
            [str, nn.Module, List[torch.Tensor]], None]] = None
    ) -> "Lowerer":
        """Instantiate a `Lowerer` instance."""

        return Lowerer(
            split=Splitter.create(not lower_setting.explicit_batch_dimension),
            acc_trace=lambda mod, input: acc_tracer.trace(
                mod,
                input,  # type: ignore[arg-type]
                ast_rewriter_allow_list=lower_setting.ast_rewriter_allow_list,
                leaf_module_list=lower_setting.leaf_module_list
            ),  # type: ignore[arg-type]
            remove_duplicate_output_args=remove_duplicate_output_args,
            trt_interpreter=LowerTrtInterpreter.create(lower_setting),
            fp16=lower_setting.fp16_mode,
            trt_module_observer=trt_module_observer,
        )
Example #28
def build_int8_trt(rn18):
    rn18 = copy.deepcopy(rn18)
    data = torch.randn(1, 3, 224, 224)
    # data = torch.randn(1, 64, 10, 10)
    # TensorRT only supports symmetric quantization
    qconfig = torch.quantization.QConfig(
        activation=torch.quantization.observer.HistogramObserver.with_args(
            qscheme=torch.per_tensor_symmetric, dtype=torch.qint8),
        weight=torch.quantization.default_weight_observer)
    prepared = prepare_fx(rn18, {"": qconfig})
    for _ in range(10):
        prepared(data)
    quantized_rn18 = convert_fx(prepared, is_reference=True)
    print("quantized model:", quantized_rn18)

    quantized_rn18 = acc_tracer.trace(quantized_rn18,
                                      [data])  # type: ignore[attr-defined]
    interp = TRTInterpreter(
        quantized_rn18,
        [InputTensorSpec(data.shape[1:], torch.float, has_batch_dim=False)])
    engine, input_names, output_names = interp.run(fp16_mode=False,
                                                   int8_mode=True)
    return TRTModule(engine, input_names, output_names)
Example #29
    def test_extend_acc_subgraph_after_split(self):
        class TestModule(torch.nn.Module):
            r"""     a (input)
                     |
                     b
                    / \
                   c   d
                    \ /
                     e
                    / \
                   |   (g1, g2, g3, g4)
                    \ / |
                     f  |
                      \ |
                       h

            c and f are not runnable on acc, while all other nodes are supported by acc.
            g1, g2, g3 and g4 should be in a fusion group; let's call it g.

            After the split we have 2 cpu subgraphs, (c) and (f), and 3 acc subgraphs, (b, d), (e, g) and (h).
            We expect 3 acc subgraphs, (b), (d, e, g) and (h), after extending the second acc subgraph,
            and expect the acc subgraphs to stay the same after extending the third acc subgraph because of
            the unbreakable fusion group.
            """
            def forward(self, a: torch.Tensor):
                b = a + a
                c = b - b
                d = b + b
                e = c + d

                # These four nodes should be in a fusion group
                g1 = e.size()
                g2 = g1[0]
                g3 = e + g2
                g4 = g3 + g2

                f = e - g3
                h = f + g4
                return h

        a = torch.randn(2)
        mod = acc_tracer.trace(TestModule(), (a, ))

        # Allow all nodes except subtract to run on the accelerator
        class CustomOpSupport(op_support.OperatorSupportBase):
            def is_node_supported(self, submodules, node):
                return op_support.get_node_target(submodules,
                                                  node) != "acc_ops.sub"

        splitter = TRTSplitter(mod, (a, ), CustomOpSupport())

        def test_splitter(splitter):
            # Manually tag nodes first in case split algorithm changes in the future
            nodes = list(splitter.module.graph.nodes)
            # b and d
            nodes[1].tag = "acc_0"
            nodes[3].tag = "acc_0"
            # c
            nodes[2].tag = "cpu_1"
            # e and g
            nodes[4].tag = "acc_2"
            nodes[5].tag = "acc_2"
            nodes[6].tag = "acc_2"
            nodes[7].tag = "acc_2"
            nodes[8].tag = "acc_2"
            # f
            nodes[9].tag = "cpu_3"
            # h
            nodes[10].tag = "acc_4"

            splitter.tags = ["acc_0", "cpu_1", "acc_2", "cpu_3", "acc_4"]
            split_module = splitter.split()
            try:
                verify_split_model(split_module, "acc_")
            except RuntimeError as err:
                self.assertEqual(str(err), ERROR_MSG_MULTI_ACC_MODULES)
            try:
                verify_split_model(split_module)
            except RuntimeError as err:
                self.assertEqual(str(err), ERROR_MSG_NO_ACC_MODULE)

            module_names = [name for name, _ in split_module.named_modules()]
            # Main module, 2 cpu submodules and 3 acc submodules
            assert len(module_names) == 6

            # 1 Placeholder, 2 Adds and 1 Output
            assert len(split_module.acc_0.graph.nodes) == 4
            # 2 Placeholders, 3 Adds, 1 Size, 1 GetItem and 1 Output
            assert len(split_module.acc_2.graph.nodes) == 8

            # Extend the second acc subgraph
            splitter.extend_acc_subgraph("acc_2")
            extend_module = splitter.split()
            try:
                verify_split_model(extend_module, "acc_")
            except RuntimeError as err:
                self.assertEqual(str(err), ERROR_MSG_MULTI_ACC_MODULES)

            # 1 Placeholder, 1 Add and 1 Output
            assert len(extend_module.acc_0.graph.nodes) == 3
            # 2 Placeholders, 4 Adds, 1 Size, 1 GetItem and 1 Output
            assert len(extend_module.acc_2.graph.nodes) == 9

            # Extend the third acc subgraph
            splitter.extend_acc_subgraph("acc_4")
            extend_module = splitter.split()
            try:
                verify_split_model(extend_module, "acc_")
            except RuntimeError as err:
                self.assertEqual(str(err), ERROR_MSG_MULTI_ACC_MODULES)

            assert len(extend_module.acc_2.graph.nodes) == 9
            # 2 Placeholders, 1 Add and 1 Output
            assert len(extend_module.acc_4.graph.nodes) == 4

        test_splitter(splitter)
Example #30
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.linear(x)
        x = self.relu(x)
        x = torch.linalg.norm(x, ord=2, dim=1)
        return x

inputs = [torch.randn(1, 10)]
model = Model().eval()

# acc_tracer is a custom fx tracer that maps nodes whose targets are PyTorch operators
# to acc ops.
traced = acc_tracer.trace(model, inputs)

# Splitter will split the model into several submodules. The name of each submodule
# will be either `run_on_acc_{}` or `run_on_gpu_{}`. Submodules named `run_on_acc_{}`
# can be fully lowered to TensorRT via fx2trt, while submodules named `run_on_gpu_{}`
# have unsupported ops and can't be lowered by fx2trt. We can still run `run_on_gpu_{}`
# submodules on GPU if the ops there have a CUDA implementation; the naming is a bit
# confusing and we'll improve it.
splitter = TRTSplitter(traced, inputs)

# Preview functionality allows us to see which ops are supported and which are
# unsupported. We can optionally dump the dot graph, which will color supported and
# unsupported ops differently.
splitter.node_support_preview(dump_graph=False)
"""
Supported node types in the model: