Code example #1
File: comparator.py | Project: leo-XUKANG/TensorRT-1
    def compare_accuracy(run_results,
                         fail_fast=False,
                         comparisons=None,
                         compare_func=None):
        """
        Args:
            run_results (RunResults): The result of Comparator.run()
            fail_fast (bool): Whether to exit after the first failure
            comparisons (List[Tuple[str, str]]):
                    Comparisons to perform, specified by runner names. For example, [(r0, r1), (r1, r2)]
                    would compare the runner named r0 with r1, and r1 with r2.
                    By default, this compares each result to the subsequent one.
            compare_func (Callable(IterationResult, IterationResult) -> OrderedDict[str, bool]):
                    A function that takes in two IterationResults, and returns a dictionary that maps output
                    names to a boolean (or anything convertible to a boolean) indicating whether outputs matched.
                    The order of arguments to this function is guaranteed to be the same as the ordering of the
                    tuples contained in `comparisons`.

        Returns:
            AccuracyResult:
                    A summary of the results of the comparisons. The order of the keys (i.e. runner pairs) is
                    guaranteed to be the same as the order of `comparisons`. For more details, see the AccuracyResult
                    docstring (e.g. help(AccuracyResult)).
        """
        def find_mismatched(match_dict):
            return [
                name for name, matched in match_dict.items()
                if not bool(matched)
            ]

        compare_func = misc.default_value(compare_func,
                                          CompareFunc.basic_compare_func())
        comparisons = misc.default_value(
            comparisons, Comparator.default_comparisons(run_results))

        accuracy_result = AccuracyResult()
        for runner0_name, runner1_name in comparisons:
            G_LOGGER.info("Accuracy Comparison | {:} vs. {:}".format(
                runner0_name, runner1_name))
            with G_LOGGER.indent():
                results0, results1 = run_results[runner0_name], run_results[
                    runner1_name]
                runner_pair = (runner0_name, runner1_name)
                accuracy_result[runner_pair] = []

                num_iters = min(len(results0), len(results1))
                for iteration, (result0,
                                result1) in enumerate(zip(results0, results1)):
                    if num_iters > 1:
                        G_LOGGER.info("Iteration: {:}".format(iteration))
                    with contextlib.ExitStack() as stack:
                        if num_iters > 1:
                            stack.enter_context(G_LOGGER.indent())
                        iteration_match_dict = compare_func(result0, result1)
                        accuracy_result[runner_pair].append(
                            iteration_match_dict)

                    mismatched_outputs = find_mismatched(iteration_match_dict)
                    if fail_fast and mismatched_outputs:
                        return accuracy_result

                G_LOGGER.extra_verbose(
                    "Finished comparing {:} with {:}".format(
                        runner0_name,
                        runner1_name,
                    ))

                passed, failed, total = accuracy_result.stats(runner_pair)
                pass_rate = accuracy_result.percentage(runner_pair) * 100.0
                if num_iters > 1 or len(comparisons) > 1:
                    msg = "Accuracy Summary | {:} vs. {:} | Passed: {:}/{:} iterations | Pass Rate: {:}%".format(
                        runner0_name, runner1_name, passed, total, pass_rate)
                    if passed == total:
                        G_LOGGER.success(msg)
                    else:
                        G_LOGGER.error(msg)
        return accuracy_result
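A minimal usage sketch for compare_accuracy, based only on the docstring and calls shown above. The runner names and the assumption that Comparator.run() takes the runners to execute are illustrative, not taken from this listing:

# Hypothetical usage sketch (runner names and Comparator.run() arguments are assumed).
run_results = Comparator.run(runners)  # assumed: returns RunResults keyed by runner name
accuracy = Comparator.compare_accuracy(
    run_results,
    comparisons=[("runner0", "runner1")],           # compare these two runners by name
    compare_func=CompareFunc.basic_compare_func(),  # default per the listing above
    fail_fast=False,
)
# stats() and percentage() are used in the implementation above, keyed by the runner pair.
passed, failed, total = accuracy.stats(("runner0", "runner1"))
print("Pass rate: {:}%".format(accuracy.percentage(("runner0", "runner1")) * 100.0))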
Code example #2
File: comparator.py | Project: leo-XUKANG/TensorRT-1
    def validate(run_results,
                 check_finite=None,
                 check_nan=None,
                 fail_fast=None):
        """
        Checks output validity.

        Args:
            run_results (Dict[str, List[IterationResult]]): The result of Comparator.run().
            check_finite (bool): Whether to fail on non-finite values. Defaults to False.
            check_nan (bool): Whether to fail on NaNs. Defaults to True.
            fail_fast (bool): Whether to fail after the first invalid value. Defaults to False.

        Returns:
            bool: True if all outputs were valid, False otherwise.
        """
        check_finite = misc.default_value(check_finite, False)
        check_nan = misc.default_value(check_nan, True)
        fail_fast = misc.default_value(fail_fast, False)

        def is_finite(output):
            non_finite = np.logical_not(np.isfinite(output))
            if np.any(non_finite):
                G_LOGGER.error("Encountered one or more non-finite values")
                G_LOGGER.error(
                    "Note: Use -vv or set logging verbosity to EXTRA_VERBOSE to display non-finite values",
                    mode=LogMode.ONCE)
                G_LOGGER.extra_verbose(
                    "Note: non-finite values at:\n{:}".format(non_finite))
                G_LOGGER.extra_verbose("Note: non-finite values:\n{:}".format(
                    output[non_finite]))
                return False
            return True

        def is_not_nan(output):
            nans = np.isnan(output)
            if np.any(nans):
                G_LOGGER.error("Encountered one or more NaNs")
                G_LOGGER.error(
                    "Note: Use -vv or set logging verbosity to EXTRA_VERBOSE to display locations of NaNs",
                    mode=LogMode.ONCE)
                G_LOGGER.extra_verbose("Note: NaNs at:\n{:}".format(nans))
                return False
            return True

        all_valid = True
        for runner_name, results in run_results.items():
            for result in results:
                for output_name, output in result.items():
                    G_LOGGER.info(
                        "Runner: {:40} | Validating output: {:} (check_finite={:}, check_nan={:})"
                        .format(runner_name, output_name, check_finite,
                                check_nan))

                    output_valid = True
                    with G_LOGGER.indent():
                        if check_nan:
                            output_valid &= is_not_nan(output)
                        if check_finite:
                            output_valid &= is_finite(output)

                        all_valid &= output_valid

                        if output_valid:
                            G_LOGGER.success(
                                "Runner: {:40} | Output: {:} is valid".format(
                                    runner_name, output_name))
                        else:
                            G_LOGGER.error(
                                "Runner: {:40} | Errors detected in output: {:}"
                                .format(runner_name, output_name))
                            if fail_fast:
                                return False

        if all_valid:
            G_LOGGER.success("Validation passed")
        else:
            G_LOGGER.error("Validation failed")
        return all_valid
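A minimal usage sketch for validate, assuming run_results was produced by Comparator.run() as described in the docstring; the flags simply override the documented defaults:

# Hypothetical usage sketch: fail as soon as any output contains NaNs or non-finite values.
if not Comparator.validate(run_results, check_finite=True, check_nan=True, fail_fast=True):
    raise RuntimeError("One or more runner outputs contained NaNs or non-finite values")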
Code example #3
File: compare.py | Project: leo-XUKANG/TensorRT-1
            def check_outputs_match(out0, out0_name, out1, out1_name):
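                # Note: rtol, atol, iter_result0, and iter_result1 are not defined in
                # this fragment; presumably they are closure variables supplied by the
                # enclosing comparison function in the project source.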
                def compute_max(buffer):
                    if misc.is_empty_shape(buffer.shape):
                        return 0
                    return np.amax(buffer)

                # Returns index of max value
                def compute_argmax(buffer):
                    if misc.is_empty_shape(buffer.shape):
                        return 0
                    return np.unravel_index(np.argmax(buffer), buffer.shape)

                def compute_min(buffer):
                    if misc.is_empty_shape(buffer.shape):
                        return 0
                    return np.amin(buffer)

                # Returns index of min value
                def compute_argmin(buffer):
                    if misc.is_empty_shape(buffer.shape):
                        return 0
                    return np.unravel_index(np.argmin(buffer), buffer.shape)

                def compute_mean(buffer):
                    if misc.is_empty_shape(buffer.shape):
                        return 0
                    return np.mean(buffer)


                def compute_required():
                    # The purpose of this function is to determine the minimum tolerances such that
                    # the outputs would be considered a match.
                    # The NumPy formula for np.isclose is absolute(out0 - out1) <= (atol + rtol * absolute(out1))
                    # So, for both absolute/relative tolerance, given either one,
                    # we can compute the required value for the other:
                    # atol = absolute(out0 - out1)
                    # atol_if_rtol = absolute(out0 - out1)  - rtol * absolute(out1)
                    # rtol = (absolute(out0 - out1) - atol) / absolute(out1)
                    if np.issubdtype(out0.dtype, np.bool_) and np.issubdtype(out1.dtype, np.bool_):
                        absdiff = np.logical_xor(out0, out1)
                    else:
                        absdiff = np.abs(out0 - out1)
                    absout1 = np.abs(out1)
                    required_atol = max(compute_max(absdiff), 0.0)
                    required_atol_if_rtol = max(compute_max(absdiff - rtol * absout1), 0.0)
                    # Suppress divide by 0 warnings
                    with np.testing.suppress_warnings() as sup:
                        sup.filter(RuntimeWarning)
                        required_rtol = max(compute_max((absdiff - atol) / absout1), 0.0)
                    return required_atol, required_atol_if_rtol, required_rtol


                def log_mismatches(mismatches):
                    try:
                        with G_LOGGER.indent():
                            G_LOGGER.super_verbose("Mismatches at:\n" + str(mismatches))
                            G_LOGGER.extra_verbose("Runner: {:40} | Mismatched values:\n{:}".format(iter_result0.runner_name, out0[mismatches]))
                            G_LOGGER.extra_verbose("Runner: {:40} | Mismatched values:\n{:}".format(iter_result1.runner_name, out1[mismatches]))
                    except Exception:
                        G_LOGGER.warning("Failed to log mismatches - this may be because the outputs are of different shapes")


                try:
                    mismatches = np.logical_not(np.isclose(out0, out1, rtol=rtol, atol=atol))
                except Exception as err:
                    G_LOGGER.warning("Failed to compare outputs with:\n{:}\nSkipping".format(err))
                    return False

                G_LOGGER.super_verbose("Runner: {:40} | Output: {:} (dtype={:}, shape={:}):\n{:}".format(
                                            iter_result0.runner_name, out0_name, out0.dtype, out0.shape, misc.indent_block(out0)))
                G_LOGGER.super_verbose("Runner: {:40} | Output: {:} (dtype={:}, shape={:}):\n{:}".format(
                                            iter_result1.runner_name, out1_name, out1.dtype, out1.shape, misc.indent_block(out1)))

                failed = np.any(mismatches)

                try:
                    required_atol, required_atol_if_rtol, required_rtol = compute_required()
                except Exception as err:
                    required_atol, required_atol_if_rtol, required_rtol = None, None, None
                    G_LOGGER.warning("Could not determine required tolerances due to an error:\n{:}".format(err))
                    log_msg = ""
                else:
                    log_msg = "Required tolerances: [atol={:.5g}] OR [rtol={:.5g}, atol={:.5g}] OR [rtol={:.5g}, atol={:.5g}]\n".format(
                                    required_atol, rtol, required_atol_if_rtol, required_rtol, atol)

                log_msg += "Runner: {:40} | Stats: mean={:.5g}, min={:.5g} at {:}, max={:.5g} at {:}\n".format(
                                iter_result0.runner_name, compute_mean(out0), compute_min(out0), compute_argmin(out0), compute_max(out0), compute_argmax(out0))
                log_msg += "Runner: {:40} | Stats: mean={:.5g}, min={:.5g} at {:}, max={:.5g} at {:}\n".format(
                                iter_result1.runner_name, compute_mean(out1), compute_min(out1), compute_argmin(out1), compute_max(out1), compute_argmax(out1))

                if failed:
                    log_mismatches(mismatches)
                    G_LOGGER.info(log_msg)
                    G_LOGGER.error("FAILED | Difference exceeds tolerance (rtol={:}, atol={:})".format(rtol, atol))
                else:
                    G_LOGGER.verbose(log_msg)
                    G_LOGGER.success("PASSED | Difference is within tolerance (rtol={:}, atol={:})".format(rtol, atol))

                G_LOGGER.extra_verbose("Finished comparing: '{:}' (dtype={:}, shape={:}) [{:}] and '{:}' (dtype={:}, shape={:}) [{:}]"
                                .format(out0_name, out0.dtype, out0.shape, iter_result0.runner_name, out1_name, out1.dtype, out1.shape, iter_result1.runner_name))
                return OutputCompareResult(not failed, required_atol, required_rtol)
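The tolerance math in compute_required() follows directly from NumPy's np.isclose condition, absolute(out0 - out1) <= atol + rtol * absolute(out1). A standalone sketch of the same computation on toy arrays (the values and tolerances are chosen only for illustration):

import numpy as np

out0 = np.array([1.0, 2.0, 3.002])
out1 = np.array([1.0, 2.0, 3.0])
rtol, atol = 1e-5, 1e-5

absdiff = np.abs(out0 - out1)
absout1 = np.abs(out1)
# Smallest atol that would make every element pass if rtol were 0.
required_atol = max(absdiff.max(), 0.0)
# Smallest atol that would make every element pass at the current rtol.
required_atol_if_rtol = max((absdiff - rtol * absout1).max(), 0.0)
# Smallest rtol that would make every element pass at the current atol;
# the errstate guard handles zeros in out1, mirroring suppress_warnings above.
with np.errstate(divide="ignore", invalid="ignore"):
    required_rtol = max(((absdiff - atol) / absout1).max(), 0.0)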