Example #1
    def run_test(
        self,
        mod,
        inputs,
        expected_ops,
        unexpected_ops=None,
        apply_passes=None,
        test_explicit_batch_dim=True,
        test_implicit_batch_dim=True,
        rtol=1e-03,
        atol=1e-03,
    ):
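        """Trace mod with acc_tracer, optionally apply graph passes, then lower it
        with TRTInterpreter and compare the TRT output against the eager module,
        with implicit and/or explicit batch dimension."""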
        mod.eval()
        mod = acc_tracer.trace(mod, inputs)

        if apply_passes is not None:
            for p in apply_passes:
                mod = p(mod)

        if test_implicit_batch_dim:
            interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
            super().run_test(mod, inputs, expected_ops, unexpected_ops, interp,
                             rtol, atol)

        if test_explicit_batch_dim:
            interp = TRTInterpreter(mod,
                                    InputTensorSpec.from_tensors(inputs),
                                    explicit_batch_dimension=True)
            super().run_test(mod, inputs, expected_ops, unexpected_ops, interp,
                             rtol, atol)
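A minimal usage sketch for this helper. The test-case base class name (AccTestCase) and the acc_ops import are assumptions, not shown in the snippet above:

    # Hypothetical test, assuming run_test is inherited from a base class such as
    # AccTestCase and that acc_ops exposes the traced op symbols.
    class TestRelu(AccTestCase):
        def test_relu(self):
            class Relu(torch.nn.Module):
                def forward(self, x):
                    return torch.relu(x)

            inputs = [torch.randn(1, 2, 3)]
            self.run_test(Relu(), inputs, expected_ops={acc_ops.relu})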
Example #2
    def run_test(
        self,
        mod,
        inputs,
        expected_ops,
        unexpected_ops=None,
        apply_passes=None,
        test_explicit_batch_dim=True,
        test_implicit_batch_dim=True,
        rtol=1e-03,
        atol=1e-03,
        precision=LowerPrecision.FP32,
    ):
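        """Trace mod with acc_tracer, apply chained passes if given, then lower it
        with TRTInterpreter and compare against the eager module at the requested
        LowerPrecision (FP32 by default)."""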
        mod.eval()
        mod = acc_tracer.trace(mod, inputs)

        if apply_passes is not None:
            pass_tracer = chain_passes(*apply_passes)
            mod = pass_tracer(mod, inputs)

        if test_implicit_batch_dim:
            interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
            super().run_test(mod, inputs, expected_ops, unexpected_ops, interp,
                             rtol, atol)

        if test_explicit_batch_dim:
            interp = TRTInterpreter(mod,
                                    InputTensorSpec.from_tensors(inputs),
                                    explicit_batch_dimension=True)
            super().run_test(mod, inputs, expected_ops, unexpected_ops, interp,
                             rtol, atol, precision)
Example #3
    def run_test_with_assert_error(
        self,
        mod,
        inputs,
        expect_error,
        test_explicit_batch_dim=True,
        test_implicit_batch_dim=True,
    ):
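        """Trace mod with acc_tracer and check that lowering/running it raises
        expect_error, with implicit and/or explicit batch dimension."""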
        mod.eval()
        mod = acc_tracer.trace(mod, inputs)

        if test_implicit_batch_dim:
            interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
            super().run_test_with_error(mod, inputs, interp, expect_error)

        if test_explicit_batch_dim:
            interp = TRTInterpreter(mod,
                                    InputTensorSpec.from_tensors(inputs),
                                    explicit_batch_dimension=True)
            super().run_test_with_error(mod, inputs, interp, expect_error)
Example #4
    def run_test_with_dynamic_shape(
        self,
        mod,
        input_specs,
        expected_ops,
        unexpected_ops=None,
        rtol=1e-03,
        atol=1e-03,
    ):
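        """Build concrete inputs from input_specs, trace mod with acc_tracer, and
        run it through TRTInterpreter with an explicit batch dimension."""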
        mod.eval()
        inputs = InputTensorSpec.create_inputs_from_specs(input_specs)
        mod = acc_tracer.trace(mod, inputs)
        interp = TRTInterpreter(mod, input_specs, explicit_batch_dimension=True)
        super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol)
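A hedged sketch of what input_specs could look like for a dynamic batch dimension. The keyword names (shape, dtype, shape_ranges) are assumptions about the InputTensorSpec constructor, with shape_ranges holding (min, opt, max) shapes:

    # Hypothetical specs: -1 marks the dynamic dimension; the field names and the
    # MyModule / acc_ops.relu placeholders are assumptions, not part of the snippet.
    input_specs = [
        InputTensorSpec(
            shape=(-1, 3, 224, 224),
            dtype=torch.float32,
            shape_ranges=[((1, 3, 224, 224), (4, 3, 224, 224), (8, 3, 224, 224))],
        ),
    ]
    self.run_test_with_dynamic_shape(MyModule(), input_specs, expected_ops={acc_ops.relu})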
Example #5
    def run_test_custom_compare_results(
        self,
        mod,
        inputs,
        expected_ops,
        interpreter,
        comparators: List[Tuple[Callable, List]],
        fp16_mode=False,
    ):
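        """Symbolically trace mod, propagate shapes, normalize args, lower with
        TRTInterpreter, and compare outputs using the supplied comparator functions."""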
        # The interpreter argument is ignored; we do not need it for vanilla tests.
        # Note this differs from the internal version; we need to fix the test case
        # after we refactor the internal callsites to use this file.
        mod = torch.fx.symbolic_trace(mod)
        shape_prop.ShapeProp(mod).propagate(*inputs)
        mod = NormalizeArgs(mod).transform()
        interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
        super().run_test_custom_compare_results(
            mod, inputs, expected_ops, interp, comparators, fp16_mode=fp16_mode
        )
Example #6
    def run_test(self, mod, inputs, expected_ops, rtol=1e-05, atol=1e-06):
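        """Symbolically trace mod, propagate shapes, normalize args, lower with
        TRTInterpreter, and compare against the eager module within rtol/atol."""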
        mod = torch.fx.symbolic_trace(mod)
        shape_prop.ShapeProp(mod).propagate(*inputs)
        mod = NormalizeArgs(mod).transform()
        interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
        super().run_test(mod, inputs, expected_ops, None, interp, rtol, atol)