def test_per_output_tol(self, mode):
    OUT0_NAME = "output0"
    OUT1_NAME = "output1"
    OUT_VALS = np.ones((4, 4))

    iter_result0 = IterationResult(outputs={OUT0_NAME: OUT_VALS, OUT1_NAME: OUT_VALS})
    iter_result1 = IterationResult(outputs={OUT0_NAME: OUT_VALS, OUT1_NAME: OUT_VALS + 1})

    # With default tolerances, out1 is wrong for the second result.
    compare_func = CompareFunc.basic_compare_func()
    acc = compare_func(iter_result0, iter_result1)
    assert acc[OUT0_NAME]
    assert not acc[OUT1_NAME]

    # But with custom tolerances, it should pass.
    tols = {
        OUT0_NAME: 0.0,
        OUT1_NAME: 1.0,
    }

    if mode == "abs":
        compare_func = CompareFunc.basic_compare_func(atol=tols)
    else:
        compare_func = CompareFunc.basic_compare_func(rtol=tols)

    acc = compare_func(iter_result0, iter_result1)
    assert acc[OUT0_NAME]
    assert acc[OUT1_NAME]

def test_per_output_tol_fallback(self, mode):
    OUT0_NAME = "output0"
    OUT1_NAME = "output1"
    OUT_VALS = np.ones((4, 4))

    iter_result0 = IterationResult(outputs={OUT0_NAME: OUT_VALS + 1, OUT1_NAME: OUT_VALS})
    iter_result1 = IterationResult(outputs={OUT0_NAME: OUT_VALS, OUT1_NAME: OUT_VALS + 1})

    acc = CompareFunc.basic_compare_func()(iter_result0, iter_result1)
    assert not acc[OUT0_NAME]
    assert not acc[OUT1_NAME]

    # Do not specify a tolerance for OUT0_NAME - it should fail with the fallback tolerance.
    tols = {
        OUT1_NAME: 1.0,
    }

    if mode == "abs":
        compare_func = CompareFunc.basic_compare_func(atol=tols)
    else:
        compare_func = CompareFunc.basic_compare_func(rtol=tols)

    acc = compare_func(iter_result0, iter_result1)
    assert not acc[OUT0_NAME]
    assert acc[OUT1_NAME]

def test_atol_rtol_either_pass(self, check_error_stat):
    # If either rtol or atol is sufficient, the compare_func should pass.
    res0 = IterationResult(outputs={"output": np.array([1, 2], dtype=np.float32)})
    res1 = IterationResult(outputs={"output": np.array([1.25, 2.5], dtype=np.float32)})

    assert not CompareFunc.basic_compare_func(check_error_stat=check_error_stat)(res0, res1)["output"]

    assert CompareFunc.basic_compare_func(check_error_stat=check_error_stat, rtol=0.25)(res0, res1)["output"]
    assert CompareFunc.basic_compare_func(check_error_stat=check_error_stat, atol=0.5)(res0, res1)["output"]
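    # Worked numbers for the tolerances above (a sketch, assuming the usual
    # np.isclose-style check |out0 - out1| <= atol + rtol * |reference|):
    #   abs diffs:  [0.25, 0.5]
    #   rtol=0.25:  the allowance scales with the output magnitudes (~[1, 2] and up),
    #               i.e. at least [0.25, 0.5] - enough on its own.
    #   atol=0.5:   a flat allowance of 0.5 also covers both diffs on its own.
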
def test_invalid_error_stat(self):
    res0 = IterationResult(outputs={"output": np.array([0, 1, 2, 3], dtype=np.float32)})
    res1 = IterationResult(outputs={"output": np.array([0.15, 1.25, 2.5, 3.75], dtype=np.float32)})

    with pytest.raises(PolygraphyException, match="Invalid choice"):
        CompareFunc.basic_compare_func(check_error_stat="invalid-stat")(res0, res1)

def test_atol_rtol_combined_pass(self):
    # We should also be able to mix them - i.e. rtol might be enough for some elements,
    # atol for others. If together they cover the entire output range, the comparison should pass.
    res0 = IterationResult(outputs={"output": np.array([0, 1, 2, 3], dtype=np.float32)})
    res1 = IterationResult(outputs={"output": np.array([0.15, 1.25, 2.5, 3.75], dtype=np.float32)})

    assert not CompareFunc.basic_compare_func()(res0, res1)["output"]

    assert not CompareFunc.basic_compare_func(atol=0.3)(res0, res1)["output"]
    assert not CompareFunc.basic_compare_func(rtol=0.25)(res0, res1)["output"]

    assert CompareFunc.basic_compare_func(atol=0.3, rtol=0.25)(res0, res1)["output"]
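    # Worked numbers for the mixed-tolerance case above (a sketch, under the same
    # assumed |out0 - out1| <= atol + rtol * |reference| check):
    #   abs diffs:        [0.15, 0.25, 0.5, 0.75]
    #   atol=0.3 alone:   covers the small elements (0.15, 0.25) but not 0.5 or 0.75.
    #   rtol=0.25 alone:  covers the larger elements but not the first, where the
    #                     reference magnitude (and hence the allowance) is near zero.
    #   combined:         atol + rtol * |reference| exceeds every diff, so it passes.
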
def test_per_output_error_stat(self):
    # output0 will only pass when using check_error_stat="mean"
    res0 = IterationResult(outputs={
        "output0": np.array([0, 1, 2, 3], dtype=np.float32),
        "output1": np.array([0, 1, 2, 3], dtype=np.float32),
    })
    res1 = IterationResult(outputs={
        "output0": np.array([0.15, 1.25, 2.5, 3.75], dtype=np.float32),
        "output1": np.array([0, 1, 2, 3], dtype=np.float32),
    })

    atol = 0.4125
    assert not CompareFunc.basic_compare_func(atol=atol)(res0, res1)["output0"]

    # Per-output error stats: output0 needs "mean" to pass; "" sets the default
    # stat for any outputs not listed explicitly.
    check_error_stat = {"output0": "mean", "": "max"}
    assert CompareFunc.basic_compare_func(check_error_stat=check_error_stat, atol=atol)(res0, res1)["output0"]
    assert CompareFunc.basic_compare_func(check_error_stat=check_error_stat, atol=atol)(res0, res1)["output1"]
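    # Why atol = 0.4125 (a sketch of the arithmetic): output0's absolute diffs are
    # [0.15, 0.25, 0.5, 0.75]. Their mean is exactly (0.15 + 0.25 + 0.5 + 0.75) / 4
    # = 0.4125, so the "mean" stat passes at this atol, while the max diff of 0.75
    # means the default "max" stat fails.
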
def test_can_compare_bool(self):
    # Use the builtin `bool` - the `np.bool` alias was deprecated and later removed from NumPy.
    iter_result0 = IterationResult(outputs={"output": np.zeros((4, 4), dtype=bool)})
    iter_result1 = IterationResult(outputs={"output": np.ones((4, 4), dtype=bool)})

    compare_func = CompareFunc.basic_compare_func()
    acc = compare_func(iter_result0, iter_result1)

    assert not acc["output"]

def check_network(self, suffix):
    """
    Checks whether the provided network is accurate compared to golden values.

    Returns:
        OrderedDict[str, OutputCompareResult]:
                A mapping of output names to an object describing whether they matched,
                and what the required tolerances were.
    """
    from polygraphy.comparator import Comparator, CompareFunc, DataLoader
    from polygraphy.backend.trt import EngineFromNetwork, TrtRunner, ModifyNetwork, SaveEngine

    with G_LOGGER.verbosity(severity=G_LOGGER.severity if self.args.show_output else G_LOGGER.CRITICAL):
        data_loader = tool_util.get_data_loader(self.args)

        self.args.strict_types = True  # HACK: Override strict types so things actually run in the right precision.
        config = tool_util.get_trt_config_loader(self.args, data_loader)(self.builder, self.network)

        suffix = "-{:}-{:}".format(suffix, self.precision)
        engine_path = misc.insert_suffix(self.args.save_engine, suffix)

        self.builder, self.network, self.parser = ModifyNetwork(
            (self.builder, self.network, self.parser), outputs=self.args.trt_outputs
        )()

        engine_loader = SaveEngine(
            EngineFromNetwork((self.builder, self.network, self.parser), config), path=engine_path
        )

        runners = [TrtRunner(engine_loader)]

        results = Comparator.run(runners, data_loader=data_loader)
        if self.args.validate:
            Comparator.validate(results)
        results.update(self.golden)

        compare_func = CompareFunc.basic_compare_func(
            atol=self.args.atol, rtol=self.args.rtol, check_shapes=not self.args.no_shape_check
        )
        accuracy_result = Comparator.compare_accuracy(results, compare_func=compare_func)

        tolerances = list(accuracy_result.values())[0][0]  # First iteration of first runner pair
        for name, req_tol in tolerances.items():
            # OutputCompareResult evaluates to True when the output matched.
            if bool(req_tol):
                G_LOGGER.success("PASSED | Output: {:} | Required Tolerances: {:}".format(name, req_tol))
            else:
                G_LOGGER.error("FAILED | Output: {:} | Required Tolerances: {:}".format(name, req_tol))
        return accuracy_result

def test_default_tol_in_map(self, mode):
    # "" can be used to indicate a global (default) tolerance
    OUT0_NAME = "output0"
    OUT_VALS = np.ones((4, 4))

    iter_result0 = IterationResult(outputs={OUT0_NAME: OUT_VALS})
    iter_result1 = IterationResult(outputs={OUT0_NAME: OUT_VALS + 1})

    tols = {
        "": 1.0,
    }

    if mode == "abs":
        compare_func = CompareFunc.basic_compare_func(atol=tols)
    else:
        compare_func = CompareFunc.basic_compare_func(rtol=tols)

    acc = compare_func(iter_result0, iter_result1)
    assert acc[OUT0_NAME]

def test_non_matching_outputs(self, shape):
    iter_result0 = IterationResult(outputs={"output": np.zeros(shape, dtype=np.float32)})
    iter_result1 = IterationResult(outputs={"output": np.ones(shape, dtype=np.float32)})

    compare_func = CompareFunc.basic_compare_func()

    with G_LOGGER.verbosity(G_LOGGER.ULTRA_VERBOSE):
        acc = compare_func(iter_result0, iter_result1)

    # Empty shapes have no elements to mismatch, so only non-empty shapes should fail.
    assert util.is_empty_shape(shape) or not acc["output"]

def test_check_error_stat(self, func, check_error_stat):
    iter_result0 = IterationResult(outputs={"output": func((100,), dtype=np.float32)})
    iter_result1 = IterationResult(outputs={"output": func((100,), dtype=np.float32)})

    iter_result0["output"][0] += 100

    # Even though the max diff is 100, atol=1 should cause the "mean" and "median"
    # stats to pass, since the error is aggregated over all 100 elements.
    compare_func = CompareFunc.basic_compare_func(check_error_stat=check_error_stat, atol=1)

    if check_error_stat in ["max", "elemwise"]:
        assert not compare_func(iter_result0, iter_result1)["output"]
    else:
        assert compare_func(iter_result0, iter_result1)["output"]
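    # The arithmetic (a sketch): exactly one of 100 elements is off, by 100.
    #   mean error:   100 / 100 = 1  <= atol=1 -> passes
    #   median error: 0              <= atol=1 -> passes
    #   max/elemwise: 100            >  atol=1 -> fails
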
def test_multiple_runners(self):
    load_tf = TF_MODELS["identity"].loader
    build_tf_session = SessionFromGraph(load_tf)
    load_serialized_onnx = BytesFromOnnx(OnnxFromTfGraph(load_tf))
    build_onnxrt_session = SessionFromOnnxBytes(load_serialized_onnx)
    load_engine = EngineFromNetwork(NetworkFromOnnxBytes(load_serialized_onnx))

    runners = [
        TfRunner(build_tf_session),
        OnnxrtRunner(build_onnxrt_session),
        TrtRunner(load_engine),
    ]

    run_results = Comparator.run(runners)
    compare_func = CompareFunc.basic_compare_func(check_shapes=version(trt.__version__) >= version("7.0"))
    assert bool(Comparator.compare_accuracy(run_results, compare_func=compare_func))

    assert len(list(run_results.values())[0]) == 1  # Default number of iterations

def test_dim_param_trt_onnxrt(self):
    load_onnx_bytes = ONNX_MODELS["dim_param"].loader
    build_onnxrt_session = SessionFromOnnx(load_onnx_bytes)
    load_engine = EngineFromNetwork(NetworkFromOnnxBytes(load_onnx_bytes))

    runners = [
        OnnxrtRunner(build_onnxrt_session),
        TrtRunner(load_engine),
    ]

    run_results = Comparator.run(runners)
    compare_func = CompareFunc.basic_compare_func(check_shapes=mod.version(trt.__version__) >= mod.version("7.0"))
    assert bool(Comparator.compare_accuracy(run_results, compare_func=compare_func))

    assert len(list(run_results.values())[0]) == 1  # Default number of iterations

def test_basic_compare_func(self):
    # Smoke test: constructing the compare function should succeed without running it.
    from polygraphy.comparator import CompareFunc

    CompareFunc.basic_compare_func(atol=1, rtol=1)