Example #1
    def test_per_output_tol(self, mode):
        OUT0_NAME = "output0"
        OUT1_NAME = "output1"
        OUT_VALS = np.ones((4, 4))

        iter_result0 = IterationResult(outputs={OUT0_NAME: OUT_VALS, OUT1_NAME: OUT_VALS})
        iter_result1 = IterationResult(outputs={OUT0_NAME: OUT_VALS, OUT1_NAME: OUT_VALS + 1})

        # With default tolerances, out1 is wrong for the second result.
        compare_func = CompareFunc.basic_compare_func()
        acc = compare_func(iter_result0, iter_result1)
        assert acc[OUT0_NAME]
        assert not acc[OUT1_NAME]

        # But with custom tolerances, it should pass.
        tols = {
            OUT0_NAME: 0.0,
            OUT1_NAME: 1.0,
        }

        if mode == "abs":
            compare_func = CompareFunc.basic_compare_func(atol=tols)
        else:
            compare_func = CompareFunc.basic_compare_func(rtol=tols)

        acc = compare_func(iter_result0, iter_result1)
        assert acc[OUT0_NAME]
        assert acc[OUT1_NAME]
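
The `mode` argument here (and in the similar tests below) is supplied by a pytest parametrization that is not part of the snippet. A plausible reconstruction, given only as an assumption since the test body merely branches on `mode == "abs"`, would be:

import pytest

# Hypothetical parametrization (not shown in the source snippet); the name of the
# non-"abs" value is a guess, but the test only distinguishes "abs" from anything else.
@pytest.mark.parametrize("mode", ["abs", "rel"])
def test_per_output_tol(self, mode):
    ...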
Example #2
    def test_per_output_tol_fallback(self, mode):
        OUT0_NAME = "output0"
        OUT1_NAME = "output1"
        OUT_VALS = np.ones((4, 4))

        iter_result0 = IterationResult(outputs={OUT0_NAME: OUT_VALS + 1, OUT1_NAME: OUT_VALS})
        iter_result1 = IterationResult(outputs={OUT0_NAME: OUT_VALS, OUT1_NAME: OUT_VALS + 1})

        acc = CompareFunc.basic_compare_func()(iter_result0, iter_result1)
        assert not acc[OUT0_NAME]
        assert not acc[OUT1_NAME]

        # Do not specify a tolerance for OUT0_NAME - it should still fail under the fallback (default) tolerance.
        tols = {
            OUT1_NAME: 1.0,
        }

        if mode == "abs":
            compare_func = CompareFunc.basic_compare_func(atol=tols)
        else:
            compare_func = CompareFunc.basic_compare_func(rtol=tols)

        acc = compare_func(iter_result0, iter_result1)
        assert not acc[OUT0_NAME]
        assert acc[OUT1_NAME]
Example #3
    def test_atol_rtol_either_pass(self, check_error_stat):
        # If either rtol/atol is sufficient, the compare_func should pass
        res0 = IterationResult(outputs={"output": np.array([1, 2], dtype=np.float32)})
        res1 = IterationResult(outputs={"output": np.array((1.25, 2.5), dtype=np.float32)})

        assert not CompareFunc.basic_compare_func(check_error_stat=check_error_stat)(res0, res1)["output"]

        assert CompareFunc.basic_compare_func(check_error_stat=check_error_stat, rtol=0.25)(res0, res1)["output"]
        assert CompareFunc.basic_compare_func(check_error_stat=check_error_stat, atol=0.5)(res0, res1)["output"]
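
For reference, the arithmetic behind the comment above, under the assumption that res0 is treated as the reference and an element passes when it satisfies either tolerance:

import numpy as np

ref = np.array([1, 2], dtype=np.float32)
out = np.array([1.25, 2.5], dtype=np.float32)

abs_diff = np.abs(out - ref)       # [0.25, 0.5]  -> fully covered by atol=0.5
rel_diff = abs_diff / np.abs(ref)  # [0.25, 0.25] -> fully covered by rtol=0.25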
Example #4
    def test_invalid_error_stat(self):
        res0 = IterationResult(outputs={"output": np.array([0, 1, 2, 3], dtype=np.float32)})
        res1 = IterationResult(outputs={"output": np.array((0.15, 1.25, 2.5, 3.75), dtype=np.float32)})

        with pytest.raises(PolygraphyException, match="Invalid choice"):
            CompareFunc.basic_compare_func(check_error_stat="invalid-stat")(res0, res1)
Example #5
    def test_atol_rtol_combined_pass(self):
        # We should also be able to mix them - i.e. rtol might be enough for some outputs, atol for others.
        # If they cover the entire output range, it should pass.
        res0 = IterationResult(outputs={"output": np.array([0, 1, 2, 3], dtype=np.float32)})
        res1 = IterationResult(outputs={"output": np.array((0.15, 1.25, 2.5, 3.75), dtype=np.float32)})

        assert not CompareFunc.basic_compare_func()(res0, res1)["output"]

        assert not CompareFunc.basic_compare_func(atol=0.3)(res0, res1)["output"]
        assert not CompareFunc.basic_compare_func(rtol=0.25)(res0, res1)["output"]

        assert CompareFunc.basic_compare_func(atol=0.3, rtol=0.25)(res0, res1)["output"]
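
The same numbers worked through, assuming an element is accepted when it is within either the absolute or the relative tolerance, with res0 as the reference:

import numpy as np

ref = np.array([0, 1, 2, 3], dtype=np.float32)
out = np.array([0.15, 1.25, 2.5, 3.75], dtype=np.float32)

abs_diff = np.abs(out - ref)  # [0.15, 0.25, 0.5, 0.75] -> atol=0.3 covers only the first two elements
with np.errstate(divide="ignore"):
    rel_diff = abs_diff / np.abs(ref)  # [inf, 0.25, 0.25, 0.25] -> rtol=0.25 covers the last three
# Only the combination of atol=0.3 and rtol=0.25 covers every element, hence the final assert.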
Example #6
    def test_per_output_error_stat(self, check_error_stat):
        # output0 will only pass when using check_error_stat=mean
        res0 = IterationResult(outputs={
            "output0": np.array([0, 1, 2, 3], dtype=np.float32),
            "output1": np.array([0, 1, 2, 3], dtype=np.float32),
        })
        res1 = IterationResult(outputs={
            "output0": np.array((0.15, 1.25, 2.5, 3.75), dtype=np.float32),
            "output1": np.array((0, 1, 2, 3), dtype=np.float32),
        })

        atol = 0.4125
        assert not CompareFunc.basic_compare_func(atol=atol)(res0, res1)["output0"]

        assert CompareFunc.basic_compare_func(check_error_stat=check_error_stat, atol=atol)(res0, res1)["output0"]
        assert CompareFunc.basic_compare_func(check_error_stat=check_error_stat, atol=atol)(res0, res1)["output1"]
Example #7
    def test_can_compare_bool(self):
        iter_result0 = IterationResult(outputs={"output": np.zeros((4, 4), dtype=np.bool)})
        iter_result1 = IterationResult(outputs={"output": np.ones((4, 4), dtype=np.bool)})

        compare_func = CompareFunc.basic_compare_func()
        acc = compare_func(iter_result0, iter_result1)

        assert not acc["output"]
Example #8
    def check_network(self, suffix):
        """
        Checks whether the provided network is accurate compared to golden values.

        Returns:
            OrderedDict[str, OutputCompareResult]:
                    A mapping of output names to an object describing whether they matched, and what the
                    required tolerances were.
        """
        from polygraphy.comparator import Comparator, CompareFunc, DataLoader
        from polygraphy.backend.trt import EngineFromNetwork, TrtRunner, ModifyNetwork, SaveEngine

        with G_LOGGER.verbosity(severity=G_LOGGER.severity if self.args.show_output else G_LOGGER.CRITICAL):
            data_loader = tool_util.get_data_loader(self.args)

            self.args.strict_types = True  # HACK: Override strict types so things actually run in the right precision.
            config = tool_util.get_trt_config_loader(self.args, data_loader)(self.builder, self.network)

            suffix = "-{:}-{:}".format(suffix, self.precision)
            engine_path = misc.insert_suffix(self.args.save_engine, suffix)

            self.builder, self.network, self.parser = ModifyNetwork(
                (self.builder, self.network, self.parser), outputs=self.args.trt_outputs
            )()

            engine_loader = SaveEngine(
                EngineFromNetwork((self.builder, self.network, self.parser), config), path=engine_path
            )

            runners = [TrtRunner(engine_loader)]

            results = Comparator.run(runners, data_loader=data_loader)
            if self.args.validate:
                Comparator.validate(results)
            results.update(self.golden)

            compare_func = CompareFunc.basic_compare_func(
                atol=self.args.atol,
                rtol=self.args.rtol,
                check_shapes=not self.args.no_shape_check)
            accuracy_result = Comparator.compare_accuracy(results, compare_func=compare_func)

        tolerances = list(accuracy_result.values())[0][0]  # First iteration of first runner pair
        for name, req_tol in tolerances.items():
            if bool(req_tol):
                G_LOGGER.success("PASSED | Output: {:} | Required Tolerances: {:}".format(name, req_tol))
            else:
                G_LOGGER.error("FAILED | Output: {:} | Required Tolerances: {:}".format(name, req_tol))
        return accuracy_result
Example #9
    def test_default_tol_in_map(self, mode):
        # "" can be used to indicate a global tolerance
        OUT0_NAME = "output0"
        OUT_VALS = np.ones((4, 4))

        iter_result0 = IterationResult(outputs={OUT0_NAME: OUT_VALS})
        iter_result1 = IterationResult(outputs={OUT0_NAME: OUT_VALS + 1})

        tols = {
            "": 1.0,
        }

        if mode == "abs":
            compare_func = CompareFunc.basic_compare_func(atol=tols)
        else:
            compare_func = CompareFunc.basic_compare_func(rtol=tols)

        acc = compare_func(iter_result0, iter_result1)
        assert acc[OUT0_NAME]
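
Based on the per-output maps shown in Examples #1 and #2, the "" default can be combined with named entries; the following is a sketch only, with illustrative output names and values:

# "" supplies the fallback tolerance for any output not listed explicitly,
# while named entries override it for specific outputs (illustrative values).
tols = {
    "": 1e-5,        # fallback tolerance for unlisted outputs
    "output1": 1.0,  # looser tolerance just for "output1"
}
compare_func = CompareFunc.basic_compare_func(atol=tols)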
Example #10
    def test_non_matching_outputs(self, shape):
        iter_result0 = IterationResult(outputs={"output": np.zeros(shape, dtype=np.float32)})
        iter_result1 = IterationResult(outputs={"output": np.ones(shape, dtype=np.float32)})

        compare_func = CompareFunc.basic_compare_func()

        with G_LOGGER.verbosity(G_LOGGER.ULTRA_VERBOSE):
            acc = compare_func(iter_result0, iter_result1)

        assert util.is_empty_shape(shape) or not acc["output"]
Example #11
    def test_check_error_stat(self, func, check_error_stat):
        iter_result0 = IterationResult(outputs={"output": func((100, ), dtype=np.float32)})
        iter_result1 = IterationResult(outputs={"output": func((100, ), dtype=np.float32)})

        iter_result0["output"][0] += 100

        # Even though the max diff is 100, atol=1 should cause this to pass since we're checking
        # against the mean error.
        compare_func = CompareFunc.basic_compare_func(check_error_stat=check_error_stat, atol=1)

        if check_error_stat in ["max", "elemwise"]:
            assert not compare_func(iter_result0, iter_result1)["output"]
        else:
            assert compare_func(iter_result0, iter_result1)["output"]
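
The arithmetic behind the comment above, assuming func produces a constant array (e.g. np.zeros) and that "mean" mode compares the mean absolute error against atol:

import numpy as np

out0 = np.zeros((100,), dtype=np.float32)
out1 = np.zeros((100,), dtype=np.float32)
out0[0] += 100

abs_err = np.abs(out0 - out1)
print(abs_err.max())   # 100.0 -> exceeds atol=1, so "max" and "elemwise" fail
print(abs_err.mean())  # 1.0   -> at atol=1, so checking the mean error passes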
Example #12
    def test_multiple_runners(self):
        load_tf = TF_MODELS["identity"].loader
        build_tf_session = SessionFromGraph(load_tf)
        load_serialized_onnx = BytesFromOnnx(OnnxFromTfGraph(load_tf))
        build_onnxrt_session = SessionFromOnnxBytes(load_serialized_onnx)
        load_engine = EngineFromNetwork(NetworkFromOnnxBytes(load_serialized_onnx))

        runners = [
            TfRunner(build_tf_session),
            OnnxrtRunner(build_onnxrt_session),
            TrtRunner(load_engine),
        ]

        run_results = Comparator.run(runners)
        compare_func = CompareFunc.basic_compare_func(check_shapes=version(trt.__version__) >= version("7.0"))
        assert bool(Comparator.compare_accuracy(run_results, compare_func=compare_func))
        assert len(list(run_results.values())[0]) == 1 # Default number of iterations
Example #13
    def test_dim_param_trt_onnxrt(self):
        load_onnx_bytes = ONNX_MODELS["dim_param"].loader
        build_onnxrt_session = SessionFromOnnx(load_onnx_bytes)
        load_engine = EngineFromNetwork(NetworkFromOnnxBytes(load_onnx_bytes))

        runners = [
            OnnxrtRunner(build_onnxrt_session),
            TrtRunner(load_engine),
        ]

        run_results = Comparator.run(runners)
        compare_func = CompareFunc.basic_compare_func(check_shapes=mod.version(trt.__version__) >= mod.version("7.0"))
        assert bool(Comparator.compare_accuracy(run_results, compare_func=compare_func))
        assert len(list(run_results.values())[0]) == 1  # Default number of iterations

    def test_basic_compare_func(self):
        from polygraphy.comparator import CompareFunc

        CompareFunc.basic_compare_func(atol=1, rtol=1)