def check_network(self, suffix):
    """
    Runs the current network with a TensorRT runner and compares the results
    against the stored golden values (``self.golden``).

    A PASSED/FAILED line is logged for every output found in the first
    iteration of the first runner pair.

    Args:
        suffix (str):
                Combined with ``self.precision`` and inserted into
                ``self.args.save_engine`` to form the engine save path.

    Returns:
        The value returned by ``Comparator.compare_accuracy``. Its first value,
        indexed at iteration 0, is the mapping of output names to
        required-tolerance objects that is logged here.
    """
    from polygraphy.comparator import Comparator, CompareFunc, DataLoader
    from polygraphy.backend.trt import EngineFromNetwork, TrtRunner, ModifyNetwork, SaveEngine

    # Unless the user asked to see output, silence everything below CRITICAL while running.
    if self.args.show_output:
        log_severity = G_LOGGER.severity
    else:
        log_severity = G_LOGGER.CRITICAL

    with G_LOGGER.verbosity(severity=log_severity):
        loader = tool_util.get_data_loader(self.args)

        # HACK: Override strict types so things actually run in the right precision.
        self.args.strict_types = True
        config_loader = tool_util.get_trt_config_loader(self.args, loader)
        config = config_loader(self.builder, self.network)

        full_suffix = "-{:}-{:}".format(suffix, self.precision)
        save_path = misc.insert_suffix(self.args.save_engine, full_suffix)

        # Apply the requested outputs; the modified handles replace the stored ones.
        modify_network = ModifyNetwork(
            (self.builder, self.network, self.parser),
            outputs=self.args.trt_outputs,
        )
        self.builder, self.network, self.parser = modify_network()

        build_engine = EngineFromNetwork((self.builder, self.network, self.parser), config)
        engine_loader = SaveEngine(build_engine, path=save_path)

        run_results = Comparator.run([TrtRunner(engine_loader)], data_loader=loader)
        if self.args.validate:
            Comparator.validate(run_results)
        # Add the golden results so there is a runner pair to compare.
        run_results.update(self.golden)

        compare_func = CompareFunc.basic_compare_func(
            atol=self.args.atol,
            rtol=self.args.rtol,
            check_shapes=not self.args.no_shape_check,
        )
        accuracy_result = Comparator.compare_accuracy(run_results, compare_func=compare_func)

    # NOTE(review): the original line structure was lost; the per-output logging is
    # assumed to sit outside the verbosity context above - confirm.
    per_pair_results = list(accuracy_result.values())
    first_pair = per_pair_results[0]
    tolerances = first_pair[0]  # First iteration of first runner pair

    for output_name, required_tol in tolerances.items():
        if not required_tol:
            G_LOGGER.error(
                "FAILED | Output: {:} | Required Tolerances: {:}".format(output_name, required_tol)
            )
        else:
            G_LOGGER.success(
                "PASSED | Output: {:} | Required Tolerances: {:}".format(output_name, required_tol)
            )

    return accuracy_result
def test_tf_save_timeline(self):
    """
    Runs the TensorFlow identity model with ``--save-timeline`` and checks that
    every file matching the timeline path (with any inserted suffix) is non-empty.
    """
    with tempfile.NamedTemporaryFile() as outpath:
        run_polygraphy_run([
            TF_MODELS["identity"].path,
            "--tf",
            "--gpu-memory-fraction=0.5",
            "--save-timeline",
            outpath.name,
        ])

        # Saved files may carry an extra suffix, so match on a wildcard.
        pattern = misc.insert_suffix(outpath.name, "*")
        timelines = glob.glob(pattern)

        # glob returns an empty list when nothing matches; without this check the
        # loop below would never run and the test would pass vacuously.
        assert timelines, "No timeline files found matching: {:}".format(pattern)

        for timeline in timelines:
            check_file_non_empty(timeline)