def compare_accuracy(run_results, fail_fast=False, comparisons=None, compare_func=None):
    """
    Compares the outputs of pairs of runners.

    Args:
        run_results (RunResults): The result of Comparator.run()
        fail_fast (bool): Whether to exit after the first failure
        comparisons (List[Tuple[str, str]]):
                Comparisons to perform, specified by runner names. For example,
                [(r0, r1), (r1, r2)] would compare the runner named r0 with r1,
                and r1 with r2. By default, each result is compared to the
                subsequent one.
        compare_func (Callable(IterationResult, IterationResult) -> OrderedDict[str, bool]):
                A function that takes two IterationResults and returns a dictionary
                mapping output names to a boolean (or anything convertible to a
                boolean) indicating whether the outputs matched. The argument order
                matches the ordering of the tuples in `comparisons`.

    Returns:
        AccuracyResult:
                A summary of the comparison results. Keys (runner pairs) appear in
                the same order as `comparisons`. See help(AccuracyResult) for details.
    """
    compare_func = misc.default_value(compare_func, CompareFunc.basic_compare_func())
    comparisons = misc.default_value(comparisons, Comparator.default_comparisons(run_results))

    summary = AccuracyResult()
    for name0, name1 in comparisons:
        G_LOGGER.info("Accuracy Comparison | {:} vs. {:}".format(name0, name1))
        with G_LOGGER.indent():
            pair = (name0, name1)
            summary[pair] = []

            iters0, iters1 = run_results[name0], run_results[name1]
            # Only as many iterations as both runners produced can be compared.
            total_iters = min(len(iters0), len(iters1))
            multi_iter = total_iters > 1

            for index, (res0, res1) in enumerate(zip(iters0, iters1)):
                if multi_iter:
                    G_LOGGER.info("Iteration: {:}".format(index))
                # Indent per-iteration logging only when there are multiple iterations.
                with contextlib.ExitStack() as ctx:
                    if multi_iter:
                        ctx.enter_context(G_LOGGER.indent())
                    match_dict = compare_func(res0, res1)
                    summary[pair].append(match_dict)
                    # Stop early on the first mismatched output if requested.
                    has_mismatch = any(not bool(matched) for matched in match_dict.values())
                    if fail_fast and has_mismatch:
                        return summary

            G_LOGGER.extra_verbose("Finished comparing {:} with {:}".format(name0, name1))

            passed, failed, total = summary.stats(pair)
            pass_rate = summary.percentage(pair) * 100.0
            if multi_iter or len(comparisons) > 1:
                msg = "Accuracy Summary | {:} vs. {:} | Passed: {:}/{:} iterations | Pass Rate: {:}%".format(
                    name0, name1, passed, total, pass_rate)
                (G_LOGGER.success if passed == total else G_LOGGER.error)(msg)
    return summary
def validate(run_results, check_finite=None, check_nan=None, fail_fast=None):
    """
    Checks output validity across all runners and iterations.

    Args:
        run_results (Dict[str, List[IterationResult]]): The result of Comparator.run().
        check_finite (bool): Whether to fail on non-finite values. Defaults to False.
        check_nan (bool): Whether to fail on NaNs. Defaults to True.
        fail_fast (bool): Whether to fail after the first invalid value. Defaults to False.

    Returns:
        bool: True if all outputs were valid, False otherwise.
    """
    check_finite = misc.default_value(check_finite, False)
    check_nan = misc.default_value(check_nan, True)
    fail_fast = misc.default_value(fail_fast, False)

    def finite_ok(arr):
        # Flags Inf/-Inf/NaN; locations and values are shown only at high verbosity.
        bad = np.logical_not(np.isfinite(arr))
        if not np.any(bad):
            return True
        G_LOGGER.error("Encountered one or more non-finite values")
        G_LOGGER.error(
            "Note: Use -vv or set logging verbosity to EXTRA_VERBOSE to display non-finite values",
            mode=LogMode.ONCE)
        G_LOGGER.extra_verbose("Note: non-finite values at:\n{:}".format(bad))
        G_LOGGER.extra_verbose("Note: non-finite values:\n{:}".format(arr[bad]))
        return False

    def nan_ok(arr):
        # Flags NaNs; locations are shown only at high verbosity.
        bad = np.isnan(arr)
        if not np.any(bad):
            return True
        G_LOGGER.error("Encountered one or more NaNs")
        G_LOGGER.error(
            "Note: Use -vv or set logging verbosity to EXTRA_VERBOSE to display locations of NaNs",
            mode=LogMode.ONCE)
        G_LOGGER.extra_verbose("Note: NaNs at:\n{:}".format(bad))
        return False

    all_valid = True
    for runner_name, iteration_results in run_results.items():
        for iter_result in iteration_results:
            for out_name, out in iter_result.items():
                G_LOGGER.info(
                    "Runner: {:40} | Validating output: {:} (check_finite={:}, check_nan={:})"
                    .format(runner_name, out_name, check_finite, check_nan))

                valid = True
                with G_LOGGER.indent():
                    if check_nan:
                        valid &= nan_ok(out)
                    if check_finite:
                        valid &= finite_ok(out)

                all_valid &= valid
                if valid:
                    G_LOGGER.success(
                        "Runner: {:40} | Output: {:} is valid".format(runner_name, out_name))
                else:
                    G_LOGGER.error(
                        "Runner: {:40} | Errors detected in output: {:}"
                        .format(runner_name, out_name))
                    if fail_fast:
                        return False

    if all_valid:
        G_LOGGER.success("Validation passed")
    else:
        G_LOGGER.error("Validation failed")
    return all_valid
def check_outputs_match(out0, out0_name, out1, out1_name):
    """
    Checks whether two outputs match within the tolerances (`rtol`, `atol`)
    from the enclosing scope, logging statistics and any mismatches.

    Args:
        out0 (np.ndarray): The first output.
        out0_name (str): The name of the first output.
        out1 (np.ndarray): The second output.
        out1_name (str): The name of the second output.

    Returns:
        OutputCompareResult:
                Whether the outputs matched, along with the minimum tolerances
                that would have been required for them to match.
        bool: False if the outputs could not be compared at all (e.g. np.isclose raised).

    NOTE(review): `rtol`, `atol`, `iter_result0`, `iter_result1`, and
    OutputCompareResult come from the enclosing scope.
    """
    # Reductions over empty tensors are undefined, so all stats helpers report 0 for them.
    def compute_max(buffer):
        if misc.is_empty_shape(buffer.shape):
            return 0
        return np.amax(buffer)

    # Returns the (unraveled) index of the max value.
    def compute_argmax(buffer):
        if misc.is_empty_shape(buffer.shape):
            return 0
        return np.unravel_index(np.argmax(buffer), buffer.shape)

    def compute_min(buffer):
        if misc.is_empty_shape(buffer.shape):
            return 0
        return np.amin(buffer)

    # Returns the (unraveled) index of the min value.
    def compute_argmin(buffer):
        if misc.is_empty_shape(buffer.shape):
            return 0
        return np.unravel_index(np.argmin(buffer), buffer.shape)

    def compute_mean(buffer):
        if misc.is_empty_shape(buffer.shape):
            return 0
        return np.mean(buffer)

    def compute_required():
        # Determines the minimum tolerances such that the outputs would be considered a match.
        # The NumPy formula for np.isclose is: absolute(out0 - out1) <= (atol + rtol * absolute(out1))
        # So, for both absolute/relative tolerance, given either one,
        # we can compute the required value for the other:
        #     atol = absolute(out0 - out1)
        #     atol_if_rtol = absolute(out0 - out1) - rtol * absolute(out1)
        #     rtol = (absolute(out0 - out1) - atol) / absolute(out1)
        if np.issubdtype(out0.dtype, np.bool_) and np.issubdtype(out1.dtype, np.bool_):
            # Subtraction is not meaningful for booleans; XOR marks differing elements instead.
            absdiff = np.logical_xor(out0, out1)
        else:
            absdiff = np.abs(out0 - out1)
        absout1 = np.abs(out1)
        required_atol = max(compute_max(absdiff), 0.0)
        required_atol_if_rtol = max(compute_max(absdiff - rtol * absout1), 0.0)
        # Suppress divide by 0 warnings
        with np.testing.suppress_warnings() as sup:
            sup.filter(RuntimeWarning)
            required_rtol = max(compute_max((absdiff - atol) / absout1), 0.0)
        return required_atol, required_atol_if_rtol, required_rtol

    def log_mismatches(mismatches):
        try:
            with G_LOGGER.indent():
                G_LOGGER.super_verbose("Mismatches at:\n" + str(mismatches))
                G_LOGGER.extra_verbose("Runner: {:40} | Mismatched values:\n{:}".format(
                    iter_result0.runner_name, out0[mismatches]))
                G_LOGGER.extra_verbose("Runner: {:40} | Mismatched values:\n{:}".format(
                    iter_result1.runner_name, out1[mismatches]))
        # BUGFIX: was a bare `except:`, which would also swallow KeyboardInterrupt/SystemExit.
        except Exception:
            G_LOGGER.warning("Failing to log mismatches - this may be because the outputs are of different shapes")

    try:
        # BUGFIX: was `np.isclose(output0, output1, ...)` - undefined names instead of
        # the `out0`/`out1` parameters; the resulting NameError was silently swallowed
        # below and misreported as a failed comparison.
        mismatches = np.logical_not(np.isclose(out0, out1, rtol=rtol, atol=atol))
    except Exception as err:
        G_LOGGER.warning("Failed to compare outputs with:\n{:}\nSkipping".format(err))
        return False

    G_LOGGER.super_verbose("Runner: {:40} | Output: {:} (dtype={:}, shape={:}):\n{:}".format(
        iter_result0.runner_name, out0_name, out0.dtype, out0.shape, misc.indent_block(out0)))
    G_LOGGER.super_verbose("Runner: {:40} | Output: {:} (dtype={:}, shape={:}):\n{:}".format(
        iter_result1.runner_name, out1_name, out1.dtype, out1.shape, misc.indent_block(out1)))

    failed = np.any(mismatches)

    try:
        required_atol, required_atol_if_rtol, required_rtol = compute_required()
    except Exception as err:
        # Tolerance computation is best-effort; report None rather than aborting the comparison.
        required_atol, required_atol_if_rtol, required_rtol = None, None, None
        G_LOGGER.warning("Could not determine required tolerances due to an error:\n{:}".format(err))
        log_msg = ""
    else:
        log_msg = "Required tolerances: [atol={:.5g}] OR [rtol={:.5g}, atol={:.5g}] OR [rtol={:.5g}, atol={:.5g}]\n".format(
            required_atol, rtol, required_atol_if_rtol, required_rtol, atol)

    log_msg += "Runner: {:40} | Stats: mean={:.5g}, min={:.5g} at {:}, max={:.5g} at {:}\n".format(
        iter_result0.runner_name, compute_mean(out0), compute_min(out0), compute_argmin(out0),
        compute_max(out0), compute_argmax(out0))
    log_msg += "Runner: {:40} | Stats: mean={:.5g}, min={:.5g} at {:}, max={:.5g} at {:}\n".format(
        iter_result1.runner_name, compute_mean(out1), compute_min(out1), compute_argmin(out1),
        compute_max(out1), compute_argmax(out1))

    if failed:
        log_mismatches(mismatches)
        G_LOGGER.info(log_msg)
        G_LOGGER.error("FAILED | Difference exceeds tolerance (rtol={:}, atol={:})".format(rtol, atol))
    else:
        G_LOGGER.verbose(log_msg)
        G_LOGGER.success("PASSED | Difference is within tolerance (rtol={:}, atol={:})".format(rtol, atol))

    G_LOGGER.extra_verbose("Finished comparing: '{:}' (dtype={:}, shape={:}) [{:}] and '{:}' (dtype={:}, shape={:}) [{:}]"
                           .format(out0_name, out0.dtype, out0.shape, iter_result0.runner_name,
                                   out1_name, out1.dtype, out1.shape, iter_result1.runner_name))
    return OutputCompareResult(not failed, required_atol, required_rtol)