def find(self):
    """
    Bisect over the number of layers to select, looking for the smallest count that passes.

    Each attempt selects ``num_layers`` layers via ``self.layer_indices`` (the "first" or "last"
    layers depending on ``self.args.mode``), applies them with ``self.mark_layers``, and checks the
    result with ``self.check_network``. The log messages describe the selected layers as running
    in ``self.precision`` precision.

    Returns:
        The layer indices for the smallest passing layer count found, or None if no layer count
        (up to and including all layers) passed.
    """
    which_layers = {"forward": "first", "reverse": "last"}[self.args.mode]

    num_layers = 0
    # Keep track of what works and what doesn't.
    # ``num_layers + 1`` is a sentinel meaning "nothing has passed yet"; it is clamped to the
    # real layer count when logged, and checked against it before returning.
    known_good = self.network.num_layers + 1
    known_bad = 0

    while known_good != known_bad and num_layers != known_good:
        with G_LOGGER.indent():
            G_LOGGER.info(
                "Last known good: {which_layers} {known_good} layer(s) in {precision} precision.\n"
                "Last known bad: {which_layers} {known_bad} layer(s) in {precision} precision".format(
                    which_layers=which_layers,
                    known_good=min(known_good, self.network.num_layers),
                    precision=self.precision,
                    known_bad=known_bad,
                ))

            indices = self.layer_indices(num_layers)
            self.mark_layers(indices)
            success = self.check_network("{:}-{:}".format(which_layers, num_layers))
            if success:
                known_good = num_layers
            else:
                known_bad = num_layers
            # Try something in between the known good value, and the known bad value.
            num_layers = math.ceil((known_bad + known_good) / 2.0)

    if known_good <= self.network.num_layers:
        # The loop can terminate right after a *failing* attempt (e.g. good=2, then 1 fails),
        # so the last ``indices`` computed are not necessarily the passing ones. Recompute the
        # selection for the smallest known-good count instead.
        # NOTE(review): assumes ``layer_indices`` is a pure function of the layer count.
        return self.layer_indices(known_good)
    return None
def log_output_stats(output, info_hist=False, runner_name=None, hist_range=None):
    """
    Log summary statistics for ``output``, followed by an indented histogram of its values.

    Args:
        output: The buffer to summarize.
        info_hist (bool): Log the histogram at INFO severity when True, otherwise at VERBOSE.
                Callers in this file pass their "comparison failed" flag here.
        runner_name (str): Forwarded to ``str_output_stats``.
        hist_range: Forwarded to ``str_histogram`` as the histogram range.
    """
    G_LOGGER.info(str_output_stats(output, runner_name))

    with G_LOGGER.indent():
        # Show histogram on failures.
        if info_hist:
            hist_severity = G_LOGGER.INFO
        else:
            hist_severity = G_LOGGER.VERBOSE
        # The histogram text is passed as a callable, presumably so the logger can defer
        # building it until it knows the message will be emitted.
        G_LOGGER.log(lambda: str_histogram(output, hist_range), severity=hist_severity)
def validate_output(runner_name, output_name, output):
    """
    Log statistics for one output and check it for NaNs and/or Infs.

    ``check_nan``, ``check_inf``, ``is_not_nan`` and ``is_finite`` are looked up from the
    enclosing scope. Both checks run (and log) independently when enabled.

    Returns:
        bool: True if every enabled check passed.
    """
    G_LOGGER.start("{:35} | Validating output: {:} (check_inf={:}, check_nan={:})".format(
        runner_name, output_name, check_inf, check_nan))

    with G_LOGGER.indent():
        comp_util.log_output_stats(output)

        is_valid = True
        if check_nan:
            is_valid &= is_not_nan(output)
        if check_inf:
            is_valid &= is_finite(output)

        if not is_valid:
            G_LOGGER.error("FAILED | Errors detected in output: {:}".format(output_name))
        else:
            G_LOGGER.finish("PASSED | Output: {:} is valid".format(output_name))

    return is_valid
def sort_artifacts(self, iteration, suffix=None):
    """
    Run the check command and move artifacts into the correct subdirectory.

    Args:
        iteration (int):
                The current iteration index. This is used to name artifacts
                and display logging messages.

        suffix (str):
                A custom suffix to add to the artifact prior to moving it.
                This will be applied in addition to the default suffix.
    Returns:
        bool: True if the command succeeded, False otherwise.
    """

    def decode(data):
        # The check command is arbitrary, so its output may not be valid UTF-8;
        # replace undecodable bytes rather than crashing the whole iteration.
        return data.decode(errors="replace")

    def move_artifacts(subdir, returncode):
        """
        Moves artifacts (self.artifacts) into the specified subdirectory of self.output,
        appending the start date/time, iteration index, and return code to each file name.
        Creates parent directories as required. Missing artifacts and existing destinations
        are logged and skipped.

        Args:
            subdir (str): The destination path as a subdirectory of self.output.
            returncode (int): The check command's return code; embedded in the artifact name.
        """
        for art in self.artifacts:
            basename, ext = os.path.splitext(os.path.basename(art))
            if suffix:
                basename += suffix
            name = "{:}_{:}_{:}_N{:}_ret{:}{:}".format(basename, self.start_date, self.start_time, iteration,
                                                      returncode, ext)
            dest = os.path.join(self.output, subdir, name)

            if not os.path.exists(art):
                G_LOGGER.error("Artifact: {:} does not exist, skipping.\n"
                               "Was the artifact supposed to be generated?".format(art))
                continue

            if os.path.exists(dest):
                G_LOGGER.error("Destination path: {:} already exists.\n"
                               "Refusing to overwrite. This artifact will be skipped!".format(dest))
                continue

            G_LOGGER.info("Moving {:} to {:}".format(art, dest))

            dir_path = os.path.dirname(dest)
            if dir_path:
                dir_path = os.path.realpath(dir_path)
                os.makedirs(dir_path, exist_ok=True)
            shutil.move(art, dest)

    def try_remove(path):
        # Returns a callback for the ExitStack below; removal is best-effort.
        def func():
            try:
                os.remove(path)
            except OSError:
                G_LOGGER.verbose("Could not remove: {:}".format(path))

        return func

    def is_success(status):
        has_fail_regex = None
        if self.fail_regexes is not None:
            output = decode(status.stdout) + decode(status.stderr)
            has_fail_regex = any(regex.search(output) is not None for regex in self.fail_regexes)

        if self.fail_codes is not None:
            # If a fail-code is specified, then we should also check has_fail_regex if provided.
            failed = status.returncode in self.fail_codes
            if has_fail_regex is not None:
                failed &= has_fail_regex
        else:
            # If a fail-code is not specified, we should trigger failures even on 0-status
            # if the fail regex is found.
            failed = status.returncode != 0 if has_fail_regex is None else has_fail_regex
        return not failed

    with contextlib.ExitStack() as stack, G_LOGGER.indent():
        # Intermediate files are cleaned up when this block exits, after artifacts are moved.
        if self.iter_artifact and self.remove_intermediate:
            stack.callback(try_remove(self.iter_artifact))

        if self.iteration_info:
            util.save_json({"iteration": iteration}, self.iteration_info)
            stack.callback(try_remove(self.iteration_info))

        G_LOGGER.info("Running check command: {:}".format(" ".join(self.check)))
        status = sp.run(self.check, stdout=sp.PIPE, stderr=sp.PIPE)
        success = is_success(status)

        if self.show_output:
            stderr_log_level = G_LOGGER.WARNING if success else G_LOGGER.ERROR
            G_LOGGER.info("========== CAPTURED STDOUT ==========\n{:}".format(decode(status.stdout)))
            G_LOGGER.log("========== CAPTURED STDERR ==========\n{:}".format(decode(status.stderr)),
                         severity=stderr_log_level)

        if success:
            move_artifacts("good", status.returncode)
            G_LOGGER.finish("PASSED | Iteration {:}".format(iteration))
            return True
        else:
            move_artifacts("bad", status.returncode)
            G_LOGGER.error("FAILED | Iteration {:}".format(iteration))
            return False
def check_outputs_match(out0, out0_name, out1, out1_name, per_out_rtol, per_out_atol, per_out_err_stat):
    """
    Compare two output buffers against the given tolerances and log detailed error metrics.

    Runner names for log messages are read from ``iter_result0`` / ``iter_result1``, which must be
    in scope where this function is defined.

    Args:
        out0, out1: The buffers to compare. ``out1`` is the reference for relative differences.
        out0_name, out1_name (str): Output names, used in log messages.
        per_out_rtol (float): Relative tolerance.
        per_out_atol (float): Absolute tolerance.
        per_out_err_stat (str): Error statistic to check: "max", "mean", "median", or "elemwise".

    Returns:
        OutputCompareResult: Whether the outputs matched, plus the max/mean/median absolute and
                relative differences.
    """
    VALID_CHECK_ERROR_STATS = ["max", "mean", "median", "elemwise"]
    if per_out_err_stat not in VALID_CHECK_ERROR_STATS:
        G_LOGGER.critical("Invalid choice for check_error_stat: {:}.\n"
                          "Note: Valid choices are: {:}".format(per_out_err_stat, VALID_CHECK_ERROR_STATS))

    G_LOGGER.super_verbose("{:35} | Output: {:} (dtype={:}, shape={:}):\n{:}".format(
        iter_result0.runner_name, out0_name, out0.dtype, out0.shape, util.indent_block(out0)))
    G_LOGGER.super_verbose("{:35} | Output: {:} (dtype={:}, shape={:}):\n{:}".format(
        iter_result1.runner_name, out1_name, out1.dtype, out1.shape, util.indent_block(out1)))

    # Check difference vs. tolerances
    if np.issubdtype(out0.dtype, np.bool_) and np.issubdtype(out1.dtype, np.bool_):
        absdiff = np.logical_xor(out0, out1)
    elif np.issubdtype(out0.dtype, np.unsignedinteger) or np.issubdtype(out1.dtype, np.unsignedinteger):
        # Plain subtraction wraps around for unsigned integers (e.g. uint8: 1 - 2 == 255),
        # so subtract the smaller value from the larger one instead.
        absdiff = np.maximum(out0, out1) - np.minimum(out0, out1)
    else:
        absdiff = np.abs(out0 - out1)

    absout1 = np.abs(out1)
    with np.testing.suppress_warnings() as sup:
        # Zeros in out1 make the relative difference inf/nan; those warnings are expected.
        sup.filter(RuntimeWarning)
        reldiff = absdiff / absout1

        max_absdiff = comp_util.compute_max(absdiff)
        mean_absdiff = comp_util.compute_mean(absdiff)
        median_absdiff = comp_util.compute_median(absdiff)
        max_reldiff = comp_util.compute_max(reldiff)
        mean_reldiff = comp_util.compute_mean(reldiff)
        median_reldiff = comp_util.compute_median(reldiff)

    # Only computed for "elemwise"; left as "Unknown" otherwise or if computing them fails.
    max_elemwiseabs = "Unknown"
    max_elemwiserel = "Unknown"

    # For aggregate statistics, an output fails only if it exceeds BOTH tolerances;
    # a NaN relative difference is treated as exceeding the relative tolerance.
    if per_out_err_stat == "mean":
        failed = mean_absdiff > per_out_atol and (np.isnan(mean_reldiff) or mean_reldiff > per_out_rtol)
    elif per_out_err_stat == "median":
        failed = median_absdiff > per_out_atol and (np.isnan(median_reldiff) or median_reldiff > per_out_rtol)
    elif per_out_err_stat == "max":
        failed = max_absdiff > per_out_atol and (np.isnan(max_reldiff) or max_reldiff > per_out_rtol)
    else:
        assert per_out_err_stat == "elemwise", "This branch should be unreachable unless per_out_err_stat is 'elemwise'"
        mismatches = (absdiff > per_out_atol) & (reldiff > per_out_rtol)

        failed = np.any(mismatches)
        try:
            # Special because we need to account for tolerances too.
            max_elemwiseabs = comp_util.compute_max(absdiff[mismatches])
            max_elemwiserel = comp_util.compute_max(reldiff[mismatches])

            with G_LOGGER.indent():
                G_LOGGER.super_verbose("Mismatched indices:\n{:}".format(np.argwhere(mismatches)))
                G_LOGGER.extra_verbose("{:35} | Mismatched values:\n{:}".format(iter_result0.runner_name,
                                                                                out0[mismatches]))
                G_LOGGER.extra_verbose("{:35} | Mismatched values:\n{:}".format(iter_result1.runner_name,
                                                                                out1[mismatches]))
        except Exception as err:
            G_LOGGER.warning("Failing to log mismatches.\nNote: Error was: {:}".format(err))

    # Log information about the outputs
    hist_bin_range = (min(comp_util.compute_min(out0), comp_util.compute_min(out1)),
                      max(comp_util.compute_max(out0), comp_util.compute_max(out1)))
    comp_util.log_output_stats(out0, failed, iter_result0.runner_name + ": " + out0_name, hist_range=hist_bin_range)
    comp_util.log_output_stats(out1, failed, iter_result1.runner_name + ": " + out1_name, hist_range=hist_bin_range)

    G_LOGGER.info("Error Metrics: {:}".format(out0_name))
    with G_LOGGER.indent():

        def req_tol(mean_diff, median_diff, max_diff, elemwise_diff):
            return {
                "mean": mean_diff,
                "median": median_diff,
                "max": max_diff,
                "elemwise": elemwise_diff,
            }[per_out_err_stat]

        def fmt_tol(value):
            # The elemwise values may still be the string "Unknown", which "{:.5g}" rejects.
            try:
                return "{:.5g}".format(value)
            except (TypeError, ValueError):
                return str(value)

        G_LOGGER.info("Minimum Required Tolerance: {:} error | [abs={:}] OR [rel={:}]".format(
            per_out_err_stat,
            fmt_tol(req_tol(mean_absdiff, median_absdiff, max_absdiff, max_elemwiseabs)),
            fmt_tol(req_tol(mean_reldiff, median_reldiff, max_reldiff, max_elemwiserel)),
        ))
        comp_util.log_output_stats(absdiff, failed, "Absolute Difference")
        comp_util.log_output_stats(reldiff, failed, "Relative Difference")

    # Finally show summary.
    if failed:
        G_LOGGER.error("FAILED | Difference exceeds tolerance (rel={:}, abs={:})".format(per_out_rtol, per_out_atol))
    else:
        G_LOGGER.finish("PASSED | Difference is within tolerance (rel={:}, abs={:})".format(per_out_rtol, per_out_atol))

    G_LOGGER.extra_verbose("Finished comparing: '{:}' (dtype={:}, shape={:}) [{:}] and '{:}' (dtype={:}, shape={:}) [{:}]"
                           .format(out0_name, out0.dtype, out0.shape, iter_result0.runner_name, out1_name, out1.dtype,
                                   out1.shape, iter_result1.runner_name))
    return OutputCompareResult(not failed, max_absdiff, max_reldiff, mean_absdiff, mean_reldiff, median_absdiff,
                               median_reldiff)
def compare_output(iter_result0, iter_result1):
    """
    Compare the outputs of two runners from a single iteration.

    This function will always iterate over the output names of the first IterationResult,
        and attempt to find corresponding output names in the second.
    If no corresponding output name is found, the output is skipped.
    If all output names are skipped, then this function raises an error.

    Args:
        iter_result0 (IterationResult): The result of the first runner.
        iter_result1 (IterationResult): The result of the second runner.

    Returns:
        OrderedDict[str, OutputCompareResult]:
                The name of the outputs compared, derived from the first IterationResult,
                and whether they matched. If an output name is not found, it is omitted from this dictionary.
                If one output is compared against several outputs, a mismatch against any of them is kept.

    Raises:
        PolygraphyException: If all output names are skipped, and thus no outputs are compared.
    """

    def check_dict(dct, dict_name):
        if isinstance(dct, dict):
            util.check_dict_contains(dct,
                                     set(iter_result0.keys()) | set(iter_result1.keys()) | {""},
                                     check_missing=False,
                                     dict_name=dict_name)

    check_dict(rtol, "the rtol dictionary")
    check_dict(atol, "the atol dictionary")
    check_dict(check_error_stat, "the check_error_stat dictionary")

    # Returns whether the outputs match
    def check_outputs_match(out0, out0_name, out1, out1_name, per_out_rtol, per_out_atol, per_out_err_stat):
        VALID_CHECK_ERROR_STATS = ["max", "mean", "median", "elemwise"]
        if per_out_err_stat not in VALID_CHECK_ERROR_STATS:
            G_LOGGER.critical("Invalid choice for check_error_stat: {:}.\n"
                              "Note: Valid choices are: {:}".format(per_out_err_stat, VALID_CHECK_ERROR_STATS))

        G_LOGGER.super_verbose("{:35} | Output: {:} (dtype={:}, shape={:}):\n{:}".format(
            iter_result0.runner_name, out0_name, out0.dtype, out0.shape, util.indent_block(out0)))
        G_LOGGER.super_verbose("{:35} | Output: {:} (dtype={:}, shape={:}):\n{:}".format(
            iter_result1.runner_name, out1_name, out1.dtype, out1.shape, util.indent_block(out1)))

        # Check difference vs. tolerances
        if np.issubdtype(out0.dtype, np.bool_) and np.issubdtype(out1.dtype, np.bool_):
            absdiff = np.logical_xor(out0, out1)
        elif np.issubdtype(out0.dtype, np.unsignedinteger) or np.issubdtype(out1.dtype, np.unsignedinteger):
            # Plain subtraction wraps around for unsigned integers (e.g. uint8: 1 - 2 == 255),
            # so subtract the smaller value from the larger one instead.
            absdiff = np.maximum(out0, out1) - np.minimum(out0, out1)
        else:
            absdiff = np.abs(out0 - out1)

        absout1 = np.abs(out1)
        with np.testing.suppress_warnings() as sup:
            # Zeros in out1 make the relative difference inf/nan; those warnings are expected.
            sup.filter(RuntimeWarning)
            reldiff = absdiff / absout1

            max_absdiff = comp_util.compute_max(absdiff)
            mean_absdiff = comp_util.compute_mean(absdiff)
            median_absdiff = comp_util.compute_median(absdiff)
            max_reldiff = comp_util.compute_max(reldiff)
            mean_reldiff = comp_util.compute_mean(reldiff)
            median_reldiff = comp_util.compute_median(reldiff)

        # Only computed for "elemwise"; left as "Unknown" otherwise or if computing them fails.
        max_elemwiseabs = "Unknown"
        max_elemwiserel = "Unknown"

        # For aggregate statistics, an output fails only if it exceeds BOTH tolerances;
        # a NaN relative difference is treated as exceeding the relative tolerance.
        if per_out_err_stat == "mean":
            failed = mean_absdiff > per_out_atol and (np.isnan(mean_reldiff) or mean_reldiff > per_out_rtol)
        elif per_out_err_stat == "median":
            failed = median_absdiff > per_out_atol and (np.isnan(median_reldiff) or median_reldiff > per_out_rtol)
        elif per_out_err_stat == "max":
            failed = max_absdiff > per_out_atol and (np.isnan(max_reldiff) or max_reldiff > per_out_rtol)
        else:
            assert per_out_err_stat == "elemwise", "This branch should be unreachable unless per_out_err_stat is 'elemwise'"
            mismatches = (absdiff > per_out_atol) & (reldiff > per_out_rtol)

            failed = np.any(mismatches)
            try:
                # Special because we need to account for tolerances too.
                max_elemwiseabs = comp_util.compute_max(absdiff[mismatches])
                max_elemwiserel = comp_util.compute_max(reldiff[mismatches])

                with G_LOGGER.indent():
                    G_LOGGER.super_verbose("Mismatched indices:\n{:}".format(np.argwhere(mismatches)))
                    G_LOGGER.extra_verbose("{:35} | Mismatched values:\n{:}".format(iter_result0.runner_name,
                                                                                    out0[mismatches]))
                    G_LOGGER.extra_verbose("{:35} | Mismatched values:\n{:}".format(iter_result1.runner_name,
                                                                                    out1[mismatches]))
            except Exception as err:
                G_LOGGER.warning("Failing to log mismatches.\nNote: Error was: {:}".format(err))

        # Log information about the outputs
        hist_bin_range = (min(comp_util.compute_min(out0), comp_util.compute_min(out1)),
                          max(comp_util.compute_max(out0), comp_util.compute_max(out1)))
        comp_util.log_output_stats(out0, failed, iter_result0.runner_name + ": " + out0_name,
                                   hist_range=hist_bin_range)
        comp_util.log_output_stats(out1, failed, iter_result1.runner_name + ": " + out1_name,
                                   hist_range=hist_bin_range)

        G_LOGGER.info("Error Metrics: {:}".format(out0_name))
        with G_LOGGER.indent():

            def req_tol(mean_diff, median_diff, max_diff, elemwise_diff):
                return {
                    "mean": mean_diff,
                    "median": median_diff,
                    "max": max_diff,
                    "elemwise": elemwise_diff,
                }[per_out_err_stat]

            def fmt_tol(value):
                # The elemwise values may still be the string "Unknown", which "{:.5g}" rejects.
                try:
                    return "{:.5g}".format(value)
                except (TypeError, ValueError):
                    return str(value)

            G_LOGGER.info("Minimum Required Tolerance: {:} error | [abs={:}] OR [rel={:}]".format(
                per_out_err_stat,
                fmt_tol(req_tol(mean_absdiff, median_absdiff, max_absdiff, max_elemwiseabs)),
                fmt_tol(req_tol(mean_reldiff, median_reldiff, max_reldiff, max_elemwiserel)),
            ))
            comp_util.log_output_stats(absdiff, failed, "Absolute Difference")
            comp_util.log_output_stats(reldiff, failed, "Relative Difference")

        # Finally show summary.
        if failed:
            G_LOGGER.error("FAILED | Difference exceeds tolerance (rel={:}, abs={:})".format(per_out_rtol, per_out_atol))
        else:
            G_LOGGER.finish("PASSED | Difference is within tolerance (rel={:}, abs={:})".format(
                per_out_rtol, per_out_atol))

        G_LOGGER.extra_verbose(
            "Finished comparing: '{:}' (dtype={:}, shape={:}) [{:}] and '{:}' (dtype={:}, shape={:}) [{:}]".format(
                out0_name, out0.dtype, out0.shape, iter_result0.runner_name, out1_name, out1.dtype, out1.shape,
                iter_result1.runner_name))
        return OutputCompareResult(not failed, max_absdiff, max_reldiff, mean_absdiff, mean_reldiff,
                                   median_absdiff, median_reldiff)

    #
    # End: def check_outputs_match
    #

    output_status = OrderedDict()  # OrderedDict[str, bool] Maps output names to whether they matched.

    if not check_shapes:
        G_LOGGER.info("Strict shape checking disabled. Will attempt to match output shapes before comparisons")

    def default_find_output_func(output_name, index, iter_result):
        found_name = util.find_in_dict(output_name, iter_result, index)
        if found_name is None:
            return None
        elif found_name != output_name:
            exact_match = util.find_in_dict(found_name, iter_result0)
            if exact_match == found_name:
                G_LOGGER.verbose("Will not compare {:} with {:}, since the former already has an exact match: {:}".format(
                    found_name, output_name, exact_match))
                return None  # If the found output is being compared against another output already, skip this non-exact match
            G_LOGGER.warning("Output names did not match exactly. Assuming {:} output: {:} "
                             "corresponds to output: {:}".format(iter_result.runner_name, found_name, output_name))
        return [found_name]

    # Resolve into a local instead of rebinding the enclosing ``find_output_func`` via
    # ``nonlocal``: the default closes over *this* call's ``iter_result0``, so persisting it
    # would make every later call check exact matches against a stale IterationResult.
    find_func = util.default(find_output_func, default_find_output_func)

    def get_tol(tol_dict, out_name, default):
        # A plain number applies to every output; otherwise look up the output name, then "".
        if isinstance(tol_dict, numbers.Number):
            return tol_dict
        if out_name in tol_dict:
            return tol_dict[out_name]
        elif "" in tol_dict:
            return tol_dict[""]
        return default

    def get_error_stat(out_name):
        # A plain string applies to every output; otherwise look up the output name, then "".
        if isinstance(check_error_stat, str):
            return check_error_stat
        if out_name in check_error_stat:
            return check_error_stat[out_name]
        elif "" in check_error_stat:
            return check_error_stat[""]
        return default_error_stat

    for index, (out0_name, output0) in enumerate(iter_result0.items()):
        out1_names = util.default(find_func(out0_name, index, iter_result1), [])

        if len(out1_names) > 1:
            G_LOGGER.info("Will attempt to compare output: '{:}' [{:}] with multiple outputs: '{:}' [{:}]".format(
                out0_name, iter_result0.runner_name, list(out1_names), iter_result1.runner_name))

        for out1_name in out1_names:
            if out1_name is None or out1_name not in iter_result1:
                G_LOGGER.warning("For output: '{:}' [{:}], skipping corresponding output: '{:}' [{:}], "
                                 "since the output was not found".format(out0_name, iter_result0.runner_name,
                                                                         out1_name, iter_result1.runner_name))
                continue

            per_out_atol = get_tol(atol, out0_name, default_atol)
            per_out_rtol = get_tol(rtol, out0_name, default_rtol)
            per_out_err_stat = get_error_stat(out0_name)

            output1 = iter_result1[out1_name]
            G_LOGGER.start("Comparing Output: '{:}' (dtype={:}, shape={:}) with '{:}' (dtype={:}, shape={:}) | "
                           "Tolerance: [abs={:.5g}, rel={:.5g}] | Checking {:} error".format(
                               out0_name, output0.dtype, output0.shape, out1_name, output1.dtype, output1.shape,
                               per_out_atol, per_out_rtol, per_out_err_stat))
            G_LOGGER.extra_verbose("Note: Comparing {:} vs. {:}".format(iter_result0.runner_name,
                                                                        iter_result1.runner_name))

            with G_LOGGER.indent():
                if check_shapes and output0.shape != output1.shape:
                    G_LOGGER.error("Will not compare outputs of different shapes. Note: Output shapes are "
                                   "{:} and {:}.".format(output0.shape, output1.shape))
                    G_LOGGER.error("Note: Use --no-strict-shape-checking or set check_shapes=False to "
                                   "attempt to compare values anyway.", mode=LogMode.ONCE)
                    outputs_match = False
                else:
                    # Reshape into locals so that ``output0`` keeps its original shape when
                    # it is compared against further outputs in this loop.
                    shaped1 = util.try_match_shape(output1, output0.shape)
                    shaped0 = output0.reshape(shaped1.shape)
                    outputs_match = check_outputs_match(shaped0, out0_name, shaped1, out1_name,
                                                        per_out_rtol=per_out_rtol,
                                                        per_out_atol=per_out_atol,
                                                        per_out_err_stat=per_out_err_stat)

                # A later successful comparison must not hide an earlier mismatch for this output.
                if output_status.get(out0_name, True):
                    output_status[out0_name] = outputs_match
                if fail_fast and not outputs_match:
                    return output_status

    mismatched_output_names = [name for name, matched in output_status.items() if not matched]
    if mismatched_output_names:
        G_LOGGER.error("FAILED | Mismatched outputs: {:}".format(mismatched_output_names))
    else:
        G_LOGGER.finish("PASSED | All outputs matched | Outputs: {:}".format(list(output_status.keys())))

    # This is useful for catching cases were Polygraphy does something wrong with the runner output buffers
    if not output_status and (bool(iter_result0.keys()) or bool(iter_result1.keys())):
        r0_name = iter_result0.runner_name
        r0_outs = list(iter_result0.keys())
        r1_name = iter_result1.runner_name
        r1_outs = list(iter_result1.keys())
        G_LOGGER.critical("All outputs were skipped, no common outputs found! Note:\n{:} outputs: "
                          "{:}\n{:} outputs: {:}".format(r0_name, r0_outs, r1_name, r1_outs))

    return output_status
def call_impl(self, builder, network):
    """
    Args:
        builder (trt.Builder):
                The TensorRT builder to use to create the configuration.
        network (trt.INetworkDefinition):
                The TensorRT network for which to create the config. The network is passed to
                each configured profile's ``fill_defaults`` before it is converted to a
                TensorRT optimization profile.

    Returns:
        trt.IBuilderConfig: The TensorRT builder configuration.
    """
    with util.FreeOnException([builder.create_builder_config()]) as (config, ):

        def try_run(func, name):
            # Missing attributes indicate the installed TensorRT version lacks the feature.
            try:
                return func()
            except AttributeError:
                trt_util.fail_unavailable("{:} in CreateConfig".format(name))

        def try_set_flag(flag_name):
            return try_run(lambda: config.set_flag(getattr(trt.BuilderFlag, flag_name)), flag_name.lower())

        # The last TensorRT profile added is reused as the calibration profile below.
        # Stays None if no profiles are configured, in which case calibration-profile setup is skipped.
        trt_profile = None
        with G_LOGGER.indent():
            G_LOGGER.verbose("Setting TensorRT Optimization Profiles")
            profiles = copy.deepcopy(self.profiles)
            for profile in profiles:
                trt_profile = profile.fill_defaults(network).to_trt(builder, network)
                config.add_optimization_profile(trt_profile)
            G_LOGGER.info("Configuring with profiles: {:}".format(profiles))

        config.max_workspace_size = int(self.max_workspace_size)

        if self.strict_types:
            try_set_flag("STRICT_TYPES")

        if self.tf32:
            try_set_flag("TF32")
        else:  # TF32 is on by default
            with contextlib.suppress(AttributeError):
                config.clear_flag(trt.BuilderFlag.TF32)

        if self.fp16:
            try_set_flag("FP16")

        if self.int8:
            try_set_flag("INT8")
            if not network.has_explicit_precision:
                if self.calibrator is not None:
                    if trt_profile is not None:
                        input_metadata = trt_util.get_input_metadata_from_profile(trt_profile, network)
                        with contextlib.suppress(AttributeError):  # Polygraphy calibrator has a reset method
                            self.calibrator.reset(input_metadata)
                    config.int8_calibrator = self.calibrator
                    if trt_profile is not None:
                        try:
                            config.set_calibration_profile(trt_profile)
                        except Exception:
                            # Best-effort: older TensorRT versions lack this API.
                            G_LOGGER.extra_verbose("Cannot set calibration profile on TensorRT 7.0 and older.")
                else:
                    G_LOGGER.warning("Network does not have explicit precision and no calibrator was provided. Please ensure "
                                     "that tensors in the network have dynamic ranges set, or provide a calibrator in order to use int8 mode.")

        if self.sparse_weights:
            try_set_flag("SPARSE_WEIGHTS")

        if self.tactic_sources is not None:
            # Combine the requested tactic sources into a bitmask.
            tactic_sources_flag = 0
            for source in self.tactic_sources:
                tactic_sources_flag |= (1 << int(source))
            try_run(lambda: config.set_tactic_sources(tactic_sources_flag), name="tactic_sources")

        try:
            if self.timing_cache_path:
                timing_cache_data = util.load_file(self.timing_cache_path, description="tactic timing cache")
                cache = config.create_timing_cache(timing_cache_data)
            else:
                # Create an empty timing cache by default so it will be populated during engine build.
                # This way, consumers of CreateConfig have the option to use the cache later.
                cache = config.create_timing_cache(b"")
        except AttributeError:
            # Timing caches are unavailable; only an error if the user explicitly asked to load one.
            if self.timing_cache_path:
                trt_util.fail_unavailable("load_timing_cache in CreateConfig")
        else:
            config.set_timing_cache(cache, ignore_mismatch=False)

        if self.algorithm_selector is not None:

            def set_algo_selector():
                config.algorithm_selector = self.algorithm_selector

            try_run(set_algo_selector, "algorithm_selector")

        return config
def validate(run_results, check_inf=None, check_nan=None, fail_fast=None):
    """
    Checks output validity.

    Args:
        run_results (Dict[str, List[IterationResult]]): The result of Comparator.run().
        check_inf (bool): Whether to fail on Infs. Defaults to False.
        check_nan (bool): Whether to fail on NaNs. Defaults to True.
        fail_fast (bool): Whether to fail after the first invalid value. Defaults to False.

    Returns:
        bool: True if all outputs were valid, False otherwise.
    """
    check_inf = util.default(check_inf, False)
    check_nan = util.default(check_nan, True)
    fail_fast = util.default(fail_fast, False)

    def is_finite(output):
        # Note: np.isfinite is also False for NaNs, so NaNs are reported here too.
        non_finite = np.logical_not(np.isfinite(output))
        if np.any(non_finite):
            G_LOGGER.error("Inf Detected | One or more non-finite values were encountered in this output")
            G_LOGGER.info(
                "Note: Use -vv or set logging verbosity to EXTRA_VERBOSE to display non-finite values",
                mode=LogMode.ONCE,
            )
            G_LOGGER.extra_verbose("Note: non-finite values at:\n{:}".format(non_finite))
            G_LOGGER.extra_verbose("Note: non-finite values:\n{:}".format(output[non_finite]))
            return False
        return True

    def is_not_nan(output):
        nans = np.isnan(output)
        if np.any(nans):
            G_LOGGER.error("NaN Detected | One or more NaNs were encountered in this output")
            G_LOGGER.info(
                "Note: Use -vv or set logging verbosity to EXTRA_VERBOSE to display locations of NaNs",
                mode=LogMode.ONCE,
            )
            G_LOGGER.extra_verbose("Note: NaNs at:\n{:}".format(nans))
            return False
        return True

    def validate_output(runner_name, output_name, output):
        G_LOGGER.start("{:35} | Validating output: {:} (check_inf={:}, check_nan={:})".format(
            runner_name, output_name, check_inf, check_nan))
        with G_LOGGER.indent():
            comp_util.log_output_stats(output)

            output_valid = True
            if check_nan:
                output_valid &= is_not_nan(output)
            if check_inf:
                output_valid &= is_finite(output)

            if output_valid:
                G_LOGGER.finish("PASSED | Output: {:} is valid".format(output_name))
            else:
                G_LOGGER.error("FAILED | Errors detected in output: {:}".format(output_name))

        return output_valid

    all_valid = True
    G_LOGGER.start("Output Validation | Runners: {:}".format(list(run_results.keys())))
    with G_LOGGER.indent():
        # Iterate (name, results) pairs via ``items()``: iterating a plain dict directly would
        # yield only its keys and break the unpacking below.
        # NOTE(review): assumes Comparator.run()'s result type also provides ``items()`` — confirm.
        for runner_name, results in run_results.items():
            for result in results:
                for output_name, output in result.items():
                    all_valid &= validate_output(runner_name, output_name, output)
                    if fail_fast and not all_valid:
                        return False

    if all_valid:
        G_LOGGER.finish("PASSED | Output Validation")
    else:
        G_LOGGER.error("FAILED | Output Validation")

    return all_valid
def compare_accuracy(run_results, fail_fast=False, comparisons=None, compare_func=None):
    """
    Compare runner results pairwise, iteration by iteration.

    Args:
        run_results (RunResults): The result of Comparator.run(); indexed by runner position,
                each entry unpacks to ``(runner_name, iteration_results)``.
        fail_fast (bool): Return as soon as any iteration reports a mismatched output.
        comparisons (List[Tuple[int, int]]):
                Pairs of runner indices to compare, e.g. [(0, 1), (1, 2)] compares the first
                runner with the second and the second with the third.
                Defaults to ``Comparator.default_comparisons(run_results)``.
        compare_func (Callable(IterationResult, IterationResult) -> OrderedDict[str, bool]):
                Compares two IterationResults and returns a mapping of output name to a value
                whose truthiness says whether that output matched. Arguments are passed in the
                same order as in each ``comparisons`` tuple. Defaults to ``CompareFunc.simple()``.

    Returns:
        AccuracyResult: Per-runner-pair lists of the dictionaries returned by ``compare_func``,
                keyed in the order of ``comparisons``. See help(AccuracyResult) for details.
    """

    def mismatched_names(match_dict):
        # Anything that is falsy once converted to bool counts as a mismatch.
        return [out_name for out_name, did_match in match_dict.items() if not bool(did_match)]

    compare_func = util.default(compare_func, CompareFunc.simple())
    comparisons = util.default(comparisons, Comparator.default_comparisons(run_results))

    accuracy_result = AccuracyResult()
    for index0, index1 in comparisons:
        entry0 = run_results[index0]
        entry1 = run_results[index1]
        runner0_name, results0 = entry0
        runner1_name, results1 = entry1

        G_LOGGER.start("Accuracy Comparison | {:} vs. {:}".format(runner0_name, runner1_name))
        with G_LOGGER.indent():
            runner_pair = (runner0_name, runner1_name)
            accuracy_result[runner_pair] = []

            # zip() stops at the shorter list, so only this many iterations are compared.
            num_iters = min(len(results0), len(results1))
            several_iters = num_iters > 1

            for iteration, (result0, result1) in enumerate(zip(results0, results1)):
                if several_iters:
                    G_LOGGER.info("Iteration: {:}".format(iteration))

                with contextlib.ExitStack() as stack:
                    # Only nest the per-iteration logs when there is more than one iteration.
                    if several_iters:
                        stack.enter_context(G_LOGGER.indent())
                    match_dict = compare_func(result0, result1)
                    accuracy_result[runner_pair].append(match_dict)

                found_mismatches = mismatched_names(match_dict)
                if fail_fast and found_mismatches:
                    return accuracy_result

        G_LOGGER.extra_verbose("Finished comparing {:} with {:}".format(runner0_name, runner1_name))

        passed, _, total = accuracy_result.stats(runner_pair)
        pass_rate = accuracy_result.percentage(runner_pair) * 100.0

        # A summary is only worth printing when there is more than one thing to summarize.
        if not (num_iters > 1 or len(comparisons) > 1):
            continue

        summary = "Accuracy Summary | {:} vs. {:} | Passed: {:}/{:} iterations | Pass Rate: {:}%".format(
            runner0_name, runner1_name, passed, total, pass_rate)
        report = G_LOGGER.finish if passed == total else G_LOGGER.error
        report(summary)

    return accuracy_result
def compare_output(iter_result0, iter_result1):
    """
    Compare the outputs of two runners from a single iteration.

    This function will always iterate over the output names of the first IterationResult,
        and attempt to find corresponding output names in the second.
    If no corresponding output name is found, the output is skipped.
    If all output names are skipped, then this function raises an error.

    Args:
        iter_result0 (IterationResult): The result of the first runner.
        iter_result1 (IterationResult): The result of the second runner.

    Returns:
        OrderedDict[str, OutputCompareResult]:
                The name of the outputs compared, derived from the first IterationResult,
                and whether they matched. If an output name is not found, it is omitted from this dictionary.
                If one output is compared against several outputs, a mismatch against any of them is kept.

    Raises:
        PolygraphyException: If all output names are skipped, and thus no outputs are compared.
    """

    def check_dict(dct, dict_name):
        if isinstance(dct, dict):
            util.check_dict_contains(
                dct,
                set(iter_result0.keys()) | set(iter_result1.keys()) | {""},
                check_missing=False,
                dict_name=dict_name,
            )

    check_dict(rtol, "the rtol dictionary")
    check_dict(atol, "the atol dictionary")
    check_dict(check_error_stat, "the check_error_stat dictionary")

    output_status = OrderedDict()  # OrderedDict[str, bool] Maps output names to whether they matched.

    if not check_shapes:
        G_LOGGER.info("Strict shape checking disabled. Will attempt to match output shapes before comparisons")

    def default_find_output_func(output_name, index, iter_result):
        found_name = util.find_in_dict(output_name, iter_result, index)
        if found_name is None:
            return None
        elif found_name != output_name:
            exact_match = util.find_in_dict(found_name, iter_result0)
            if exact_match == found_name:
                G_LOGGER.verbose(
                    "Will not compare {:} with {:}, since the former already has an exact match: {:}".format(
                        found_name, output_name, exact_match))
                return None  # If the found output is being compared against another output already, skip this non-exact match
            G_LOGGER.warning("Output names did not match exactly. Assuming {:} output: {:} "
                             "corresponds to output: {:}".format(iter_result.runner_name, found_name, output_name))
        return [found_name]

    # Resolve into a local instead of rebinding the enclosing ``find_output_func`` via
    # ``nonlocal``: the default closes over *this* call's ``iter_result0``, so persisting it
    # would make every later call check exact matches against a stale IterationResult.
    find_func = util.default(find_output_func, default_find_output_func)

    for index, (out0_name, output0) in enumerate(iter_result0.items()):
        out1_names = util.default(find_func(out0_name, index, iter_result1), [])

        if len(out1_names) > 1:
            G_LOGGER.info("Will attempt to compare output: '{:}' [{:}] with multiple outputs: '{:}' [{:}]".format(
                out0_name, iter_result0.runner_name, list(out1_names), iter_result1.runner_name))

        for out1_name in out1_names:
            if out1_name is None or out1_name not in iter_result1:
                G_LOGGER.warning("For output: '{:}' [{:}], skipping corresponding output: '{:}' [{:}], "
                                 "since the output was not found".format(out0_name, iter_result0.runner_name,
                                                                         out1_name, iter_result1.runner_name))
                continue

            per_out_atol = util.value_or_from_dict(atol, out0_name, default_atol)
            per_out_rtol = util.value_or_from_dict(rtol, out0_name, default_rtol)
            per_out_err_stat = util.value_or_from_dict(check_error_stat, out0_name, default_error_stat)

            output1 = iter_result1[out1_name]
            G_LOGGER.start("Comparing Output: '{:}' (dtype={:}, shape={:}) with '{:}' (dtype={:}, shape={:}) | "
                           "Tolerance: [abs={:.5g}, rel={:.5g}] | Checking {:} error".format(
                               out0_name,
                               output0.dtype,
                               output0.shape,
                               out1_name,
                               output1.dtype,
                               output1.shape,
                               per_out_atol,
                               per_out_rtol,
                               per_out_err_stat,
                           ))
            G_LOGGER.extra_verbose("Note: Comparing {:} vs. {:}".format(iter_result0.runner_name,
                                                                        iter_result1.runner_name))

            with G_LOGGER.indent():
                if check_shapes and output0.shape != output1.shape:
                    G_LOGGER.error("Will not compare outputs of different shapes. Note: Output shapes are "
                                   "{:} and {:}.".format(output0.shape, output1.shape))
                    G_LOGGER.error(
                        "Note: Use --no-shape-check or set check_shapes=False to "
                        "attempt to compare values anyway.",
                        mode=LogMode.ONCE,
                    )
                    outputs_match = False
                else:
                    # Reshape into locals so that ``output0`` keeps its original shape when
                    # it is compared against further outputs in this loop.
                    shaped1 = util.try_match_shape(output1, output0.shape)
                    shaped0 = output0.reshape(shaped1.shape)
                    outputs_match = check_outputs_match(
                        shaped0,
                        out0_name,
                        shaped1,
                        out1_name,
                        per_out_rtol=per_out_rtol,
                        per_out_atol=per_out_atol,
                        per_out_err_stat=per_out_err_stat,
                        runner0_name=iter_result0.runner_name,
                        runner1_name=iter_result1.runner_name,
                    )

                # A later successful comparison must not hide an earlier mismatch for this output.
                if output_status.get(out0_name, True):
                    output_status[out0_name] = outputs_match
                if fail_fast and not outputs_match:
                    return output_status

    mismatched_output_names = [name for name, matched in output_status.items() if not matched]
    if mismatched_output_names:
        G_LOGGER.error("FAILED | Mismatched outputs: {:}".format(mismatched_output_names))
    else:
        G_LOGGER.finish("PASSED | All outputs matched | Outputs: {:}".format(list(output_status.keys())))

    # This is useful for catching cases were Polygraphy does something wrong with the runner output buffers
    if not output_status and (bool(iter_result0.keys()) or bool(iter_result1.keys())):
        r0_name = iter_result0.runner_name
        r0_outs = list(iter_result0.keys())
        r1_name = iter_result1.runner_name
        r1_outs = list(iter_result1.keys())
        G_LOGGER.critical("All outputs were skipped, no common outputs found! Note:\n{:} outputs: "
                          "{:}\n{:} outputs: {:}".format(r0_name, r0_outs, r1_name, r1_outs))

    return output_status