Exemplo n.º 1
0
        def coerce_cached_input(index, name, dtype, shape):
            """
            Fetch the cached buffer for an input and coerce it to the requested dtype and shape.

            The buffer is looked up in ``self.cache[iteration]`` using ``util.find_in_dict``.
            A warning is logged whenever the cached name, dtype or shape differs from what
            was requested, and the buffer is then cast and/or reshaped accordingly.

            Args:
                index: Position of the input; forwarded to ``util.find_in_dict``.
                name (str): The name of the input tensor.
                dtype: The data type the returned buffer must have.
                shape: The shape the returned buffer must be a valid override for.

            Returns:
                The cached buffer, after any required cast and reshape.

            Raises:
                AssertionError: If no cached buffer can be found for the input, or if the
                        buffer still does not have the requested dtype/shape after coercion.
            """
            cached_feed_dict = self.cache[iteration]
            cached_name = util.find_in_dict(name, cached_feed_dict, index)
            assert cached_name is not None, "No cached buffer found for input tensor: {:}".format(name)

            if cached_name != name:
                G_LOGGER.warning("Input tensor: {:} | Cached buffer name ({:}) does not match input name ({:}).".format(
                                    name, cached_name, name))

            buffer = cached_feed_dict[cached_name]

            if dtype != buffer.dtype:
                G_LOGGER.warning("Input tensor: {:} | Cached buffer dtype ({:}) does not match input dtype ({:}), attempting cast. ".format(
                                    name, buffer.dtype, np.dtype(dtype).name))

                # Only integer and floating-point targets have a representable range to check against.
                type_info = None
                if np.issubdtype(dtype, np.integer):
                    type_info = np.iinfo(np.dtype(dtype))
                elif np.issubdtype(dtype, np.floating):
                    type_info = np.finfo(np.dtype(dtype))

                if type_info is not None and np.any((buffer < type_info.min) | (buffer > type_info.max)):
                    # Report the dtype by name, as the dtype-mismatch warning above does.
                    G_LOGGER.warning("Some values in this input are out of range of {:}. Unexpected behavior may ensue!".format(
                                        np.dtype(dtype).name))
                buffer = buffer.astype(dtype)

            if not util.is_valid_shape_override(buffer.shape, shape):
                G_LOGGER.warning("Input tensor: {:} | Cached buffer shape ({:}) does not match input shape ({:}), attempting reshape. ".format(
                                    name, buffer.shape, shape))
                buffer = util.try_match_shape(buffer, shape)

            # Sanity check: the coercion above must have produced exactly what was requested.
            assert buffer.dtype == dtype and util.is_valid_shape_override(buffer.shape, shape), (
                "Input tensor: {:} | Could not coerce cached buffer (dtype={:}, shape={:}) to dtype={:}, shape={:}".format(
                    name, buffer.dtype, buffer.shape, np.dtype(dtype).name, shape))
            return buffer
Exemplo n.º 2
0
        def compare_output(iter_result0, iter_result1):
            """
            Compare the outputs of two runners from a single iteration.

            This function will always iterate over the output names of the first IterationResult,
                and attempt to find corresponding output names in the second.
            If no corresponding output name is found, the output is skipped.
            If all output names are skipped, then this function raises an error.

            Args:
                iter_result0 (IterationResult): The result of the first runner.
                iter_result1 (IterationResult): The result of the second runner.

            Returns:
                OrderedDict[str, OutputCompareResult]:
                        The name of the outputs compared, derived from the first IterationResult,
                        and whether they matched. If an output name is not found, it is omitted from this dictionary.

            Raises:
                PolygraphyException: If all output names are skipped, and thus no outputs are compared.
            """
            def check_dict(dct, dict_name):
                if isinstance(dct, dict):
                    util.check_dict_contains(dct, set(iter_result0.keys()) | set(iter_result1.keys()) | set([""]),
                                             check_missing=False, dict_name=dict_name)


            check_dict(rtol, "the rtol dictionary")
            check_dict(atol, "the atol dictionary")
            check_dict(check_error_stat, "the chcek_error_stat dictionary")


            # Returns whether the outputs match
            def check_outputs_match(out0, out0_name, out1, out1_name, per_out_rtol, per_out_atol, per_out_err_stat):
                VALID_CHECK_ERROR_STATS = ["max", "mean", "median", "elemwise"]
                if per_out_err_stat not in VALID_CHECK_ERROR_STATS:
                    G_LOGGER.critical("Invalid choice for check_error_stat: {:}.\n"
                                      "Note: Valid choices are: {:}".format(per_out_err_stat, VALID_CHECK_ERROR_STATS))

                G_LOGGER.super_verbose("{:35} | Output: {:} (dtype={:}, shape={:}):\n{:}".format(
                                            iter_result0.runner_name, out0_name, out0.dtype, out0.shape, util.indent_block(out0)))
                G_LOGGER.super_verbose("{:35} | Output: {:} (dtype={:}, shape={:}):\n{:}".format(
                                            iter_result1.runner_name, out1_name, out1.dtype, out1.shape, util.indent_block(out1)))

                # Check difference vs. tolerances
                if np.issubdtype(out0.dtype, np.bool_) and np.issubdtype(out1.dtype, np.bool_):
                    absdiff = np.logical_xor(out0, out1)
                else:
                    absdiff = np.abs(out0 - out1)

                absout1 = np.abs(out1)
                with np.testing.suppress_warnings() as sup:
                    sup.filter(RuntimeWarning)
                    reldiff = absdiff / absout1

                max_absdiff = comp_util.compute_max(absdiff)
                mean_absdiff = comp_util.compute_mean(absdiff)
                median_absdiff = comp_util.compute_median(absdiff)
                max_reldiff = comp_util.compute_max(reldiff)
                mean_reldiff = comp_util.compute_mean(reldiff)
                median_reldiff = comp_util.compute_median(reldiff)

                max_elemwiseabs = "Unknown"
                max_elemwiserel = "Unknown"

                if per_out_err_stat == "mean":
                    failed = mean_absdiff > per_out_atol and (np.isnan(mean_reldiff) or mean_reldiff > per_out_rtol)
                elif per_out_err_stat == "median":
                    failed = median_absdiff > per_out_atol and (np.isnan(median_reldiff) or median_reldiff > per_out_rtol)
                elif per_out_err_stat == "max":
                    failed = max_absdiff > per_out_atol and (np.isnan(max_reldiff) or max_reldiff > per_out_rtol)
                else:
                    assert per_out_err_stat == "elemwise", "This branch should be unreachable unless per_out_err_stat is 'elemwise'"
                    mismatches = (absdiff > per_out_atol) & (reldiff > per_out_rtol)

                    failed = np.any(mismatches)
                    try:
                        # Special because we need to account for tolerances too.
                        max_elemwiseabs = comp_util.compute_max(absdiff[mismatches])
                        max_elemwiserel = comp_util.compute_max(reldiff[mismatches])

                        with G_LOGGER.indent():
                            G_LOGGER.super_verbose("Mismatched indices:\n{:}".format(np.argwhere(mismatches)))
                            G_LOGGER.extra_verbose("{:35} | Mismatched values:\n{:}".format(iter_result0.runner_name, out0[mismatches]))
                            G_LOGGER.extra_verbose("{:35} | Mismatched values:\n{:}".format(iter_result1.runner_name, out1[mismatches]))
                    except Exception as err:
                        G_LOGGER.warning("Failing to log mismatches.\nNote: Error was: {:}".format(err))

                # Log information about the outputs
                hist_bin_range = (min(comp_util.compute_min(out0), comp_util.compute_min(out1)),
                                  max(comp_util.compute_max(out0), comp_util.compute_max(out1)))
                comp_util.log_output_stats(out0, failed, iter_result0.runner_name + ": " + out0_name, hist_range=hist_bin_range)
                comp_util.log_output_stats(out1, failed, iter_result1.runner_name + ": " + out1_name, hist_range=hist_bin_range)

                G_LOGGER.info("Error Metrics: {:}".format(out0_name))
                with G_LOGGER.indent():
                    def req_tol(mean_diff, median_diff, max_diff, elemwise_diff):
                        return {
                            "mean": mean_diff,
                            "median": median_diff,
                            "max": max_diff,
                            "elemwise": elemwise_diff,
                        }[per_out_err_stat]

                    G_LOGGER.info("Minimum Required Tolerance: {:} error | [abs={:.5g}] OR [rel={:.5g}]".format(
                                    per_out_err_stat,
                                    req_tol(mean_absdiff, median_absdiff, max_absdiff, max_elemwiseabs),
                                    req_tol(mean_reldiff, median_reldiff, max_reldiff, max_elemwiserel)))
                    comp_util.log_output_stats(absdiff, failed, "Absolute Difference")
                    comp_util.log_output_stats(reldiff, failed, "Relative Difference")

                # Finally show summary.
                if failed:
                    G_LOGGER.error("FAILED | Difference exceeds tolerance (rel={:}, abs={:})".format(per_out_rtol, per_out_atol))
                else:
                    G_LOGGER.finish("PASSED | Difference is within tolerance (rel={:}, abs={:})".format(per_out_rtol, per_out_atol))

                G_LOGGER.extra_verbose("Finished comparing: '{:}' (dtype={:}, shape={:}) [{:}] and '{:}' (dtype={:}, shape={:}) [{:}]"
                                .format(out0_name, out0.dtype, out0.shape, iter_result0.runner_name, out1_name, out1.dtype, out1.shape, iter_result1.runner_name))
                return OutputCompareResult(not failed, max_absdiff, max_reldiff, mean_absdiff, mean_reldiff, median_absdiff, median_reldiff)
                #
                # End: def check_outputs_match
                #

            output_status = OrderedDict() # OrderedDict[str, bool] Maps output names to whether they matched.

            if not check_shapes:
                G_LOGGER.info("Strict shape checking disabled. Will attempt to match output shapes before comparisons")


            def default_find_output_func(output_name, index, iter_result):
                found_name = util.find_in_dict(output_name, iter_result, index)
                if found_name is None:
                    return None
                elif found_name != output_name:
                    exact_match = util.find_in_dict(found_name, iter_result0)
                    if exact_match == found_name:
                        G_LOGGER.verbose("Will not compare {:} with {:}, since the former already has an exact match: {:}".format(
                                            found_name, output_name, exact_match))
                        return None # If the found output is being compared against another output already, skip this non-exact match
                    G_LOGGER.warning("Output names did not match exactly. Assuming {:} output: {:} "
                                    "corresponds to output: {:}".format(
                                        iter_result.runner_name, found_name, output_name))
                return [found_name]


            nonlocal find_output_func
            find_output_func = util.default(find_output_func, default_find_output_func)

            for index, (out0_name, output0) in enumerate(iter_result0.items()):
                out1_names = util.default(find_output_func(out0_name, index, iter_result1), [])

                if len(out1_names) > 1:
                    G_LOGGER.info("Will attempt to compare output: '{:}' [{:}] with multiple outputs: '{:}' [{:}]".format(
                                    out0_name, iter_result0.runner_name, list(out1_names), iter_result1.runner_name))

                for out1_name in out1_names:
                    if out1_name is None or out1_name not in iter_result1:
                        G_LOGGER.warning("For output: '{:}' [{:}], skipping corresponding output: '{:}' [{:}], "
                                         "since the output was not found".format(out0_name, iter_result0.runner_name,
                                                                                 out1_name, iter_result1.runner_name))
                        continue


                    def get_tol(tol_dict, default):
                        if isinstance(tol_dict, numbers.Number):
                            return tol_dict

                        if out0_name in tol_dict:
                            return tol_dict[out0_name]
                        elif "" in tol_dict:
                            return tol_dict[""]
                        return default


                    def get_error_stat():
                        if isinstance(check_error_stat, str):
                            return check_error_stat

                        if out0_name in check_error_stat:
                            return check_error_stat[out0_name]
                        elif "" in check_error_stat:
                            return  check_error_stat[""]
                        return default_error_stat


                    per_out_atol = get_tol(atol, default_atol)
                    per_out_rtol = get_tol(rtol, default_rtol)
                    per_out_err_stat = get_error_stat()

                    output1 = iter_result1[out1_name]
                    G_LOGGER.start("Comparing Output: '{:}' (dtype={:}, shape={:}) with '{:}' (dtype={:}, shape={:}) | "
                                   "Tolerance: [abs={:.5g}, rel={:.5g}] | Checking {:} error".format(
                                        out0_name, output0.dtype, output0.shape,
                                        out1_name, output1.dtype, output1.shape,
                                        per_out_atol, per_out_rtol, per_out_err_stat))
                    G_LOGGER.extra_verbose("Note: Comparing {:} vs. {:}".format(iter_result0.runner_name, iter_result1.runner_name))


                    with G_LOGGER.indent():
                        if check_shapes and output0.shape != output1.shape:
                            G_LOGGER.error("Will not compare outputs of different shapes. Note: Output shapes are "
                                           "{:} and {:}.".format(output0.shape, output1.shape))
                            G_LOGGER.error("Note: Use --no-strict-shape-checking or set check_shapes=False to "
                                           "attempt to compare values anyway.", mode=LogMode.ONCE)
                            outputs_match = False
                        else:
                            output1 = util.try_match_shape(output1, output0.shape)
                            output0 = output0.reshape(output1.shape)
                            outputs_match = check_outputs_match(output0, out0_name, output1, out1_name,
                                                                per_out_rtol=per_out_rtol, per_out_atol=per_out_atol,
                                                                per_out_err_stat=per_out_err_stat)

                        output_status[out0_name] = outputs_match
                        if fail_fast and not outputs_match:
                            return output_status


            mismatched_output_names = [name for name, matched in output_status.items() if not matched]
            if mismatched_output_names:
                G_LOGGER.error("FAILED | Mismatched outputs: {:}".format(mismatched_output_names))
            else:
                G_LOGGER.finish("PASSED | All outputs matched | Outputs: {:}".format(list(output_status.keys())))

            # This is useful for catching cases were Polygraphy does something wrong with the runner output buffers
            if not output_status and (bool(iter_result0.keys()) or bool(iter_result1.keys())):
                r0_name = iter_result0.runner_name
                r0_outs = list(iter_result0.keys())
                r1_name = iter_result1.runner_name
                r1_outs = list(iter_result1.keys())
                G_LOGGER.critical("All outputs were skipped, no common outputs found! Note:\n{:} outputs: "
                                  "{:}\n{:} outputs: {:}".format(r0_name, r0_outs, r1_name, r1_outs))

            return output_status
Exemplo n.º 3
0
        def compare_output(iter_result0, iter_result1):
            """
            Compare the outputs of two runners from a single iteration.

            This function will always iterate over the output names of the first IterationResult,
                and attempt to find corresponding output names in the second.
            If no corresponding output name is found, the output is skipped.
            If all output names are skipped, then this function raises an error.

            Args:
                iter_result0 (IterationResult): The result of the first runner.
                iter_result1 (IterationResult): The result of the second runner.

            Returns:
                OrderedDict[str, OutputCompareResult]:
                        The name of the outputs compared, derived from the first IterationResult,
                        and whether they matched. If an output name is not found, it is omitted from this dictionary.

            Raises:
                PolygraphyException: If all output names are skipped, and thus no outputs are compared.
            """
            def check_dict(dct, dict_name):
                if isinstance(dct, dict):
                    util.check_dict_contains(
                        dct,
                        set(iter_result0.keys()) | set(iter_result1.keys())
                        | {""},
                        check_missing=False,
                        dict_name=dict_name,
                    )

            check_dict(rtol, "the rtol dictionary")
            check_dict(atol, "the atol dictionary")
            check_dict(check_error_stat, "the check_error_stat dictionary")

            output_status = OrderedDict(
            )  # OrderedDict[str, bool] Maps output names to whether they matched.

            if not check_shapes:
                G_LOGGER.info(
                    "Strict shape checking disabled. Will attempt to match output shapes before comparisons"
                )

            def default_find_output_func(output_name, index, iter_result):
                found_name = util.find_in_dict(output_name, iter_result, index)
                if found_name is None:
                    return None
                elif found_name != output_name:
                    exact_match = util.find_in_dict(found_name, iter_result0)
                    if exact_match == found_name:
                        G_LOGGER.verbose(
                            "Will not compare {:} with {:}, since the former already has an exact match: {:}"
                            .format(found_name, output_name, exact_match))
                        return None  # If the found output is being compared against another output already, skip this non-exact match
                    G_LOGGER.warning(
                        "Output names did not match exactly. Assuming {:} output: {:} "
                        "corresponds to output: {:}".format(
                            iter_result.runner_name, found_name, output_name))
                return [found_name]

            nonlocal find_output_func
            find_output_func = util.default(find_output_func,
                                            default_find_output_func)

            for index, (out0_name, output0) in enumerate(iter_result0.items()):
                out1_names = util.default(
                    find_output_func(out0_name, index, iter_result1), [])

                if len(out1_names) > 1:
                    G_LOGGER.info(
                        "Will attempt to compare output: '{:}' [{:}] with multiple outputs: '{:}' [{:}]"
                        .format(out0_name, iter_result0.runner_name,
                                list(out1_names), iter_result1.runner_name))

                for out1_name in out1_names:
                    if out1_name is None or out1_name not in iter_result1:
                        G_LOGGER.warning(
                            "For output: '{:}' [{:}], skipping corresponding output: '{:}' [{:}], "
                            "since the output was not found".format(
                                out0_name, iter_result0.runner_name, out1_name,
                                iter_result1.runner_name))
                        continue

                    per_out_atol = util.value_or_from_dict(
                        atol, out0_name, default_atol)
                    per_out_rtol = util.value_or_from_dict(
                        rtol, out0_name, default_rtol)
                    per_out_err_stat = util.value_or_from_dict(
                        check_error_stat, out0_name, default_error_stat)

                    output1 = iter_result1[out1_name]
                    G_LOGGER.start(
                        "Comparing Output: '{:}' (dtype={:}, shape={:}) with '{:}' (dtype={:}, shape={:}) | "
                        "Tolerance: [abs={:.5g}, rel={:.5g}] | Checking {:} error"
                        .format(
                            out0_name,
                            output0.dtype,
                            output0.shape,
                            out1_name,
                            output1.dtype,
                            output1.shape,
                            per_out_atol,
                            per_out_rtol,
                            per_out_err_stat,
                        ))
                    G_LOGGER.extra_verbose(
                        "Note: Comparing {:} vs. {:}".format(
                            iter_result0.runner_name,
                            iter_result1.runner_name))

                    with G_LOGGER.indent():
                        if check_shapes and output0.shape != output1.shape:
                            G_LOGGER.error(
                                "Will not compare outputs of different shapes. Note: Output shapes are "
                                "{:} and {:}.".format(output0.shape,
                                                      output1.shape))
                            G_LOGGER.error(
                                "Note: Use --no-shape-check or set check_shapes=False to "
                                "attempt to compare values anyway.",
                                mode=LogMode.ONCE,
                            )
                            outputs_match = False
                        else:
                            output1 = util.try_match_shape(
                                output1, output0.shape)
                            output0 = output0.reshape(output1.shape)
                            outputs_match = check_outputs_match(
                                output0,
                                out0_name,
                                output1,
                                out1_name,
                                per_out_rtol=per_out_rtol,
                                per_out_atol=per_out_atol,
                                per_out_err_stat=per_out_err_stat,
                                runner0_name=iter_result0.runner_name,
                                runner1_name=iter_result1.runner_name,
                            )

                        output_status[out0_name] = outputs_match
                        if fail_fast and not outputs_match:
                            return output_status

            mismatched_output_names = [
                name for name, matched in output_status.items() if not matched
            ]
            if mismatched_output_names:
                G_LOGGER.error("FAILED | Mismatched outputs: {:}".format(
                    mismatched_output_names))
            else:
                G_LOGGER.finish(
                    "PASSED | All outputs matched | Outputs: {:}".format(
                        list(output_status.keys())))

            # This is useful for catching cases were Polygraphy does something wrong with the runner output buffers
            if not output_status and (bool(iter_result0.keys())
                                      or bool(iter_result1.keys())):
                r0_name = iter_result0.runner_name
                r0_outs = list(iter_result0.keys())
                r1_name = iter_result1.runner_name
                r1_outs = list(iter_result1.keys())
                G_LOGGER.critical(
                    "All outputs were skipped, no common outputs found! Note:\n{:} outputs: "
                    "{:}\n{:} outputs: {:}".format(r0_name, r0_outs, r1_name,
                                                   r1_outs))

            return output_status
Exemplo n.º 4
0
def test_shape_matching(case):
    """Check that ``util.try_match_shape`` yields a buffer with the expected shape.

    ``case`` is a 3-tuple of (input buffer, requested shape, expected resulting shape).
    """
    buffer, requested_shape, expected_shape = case
    matched = util.try_match_shape(buffer, requested_shape)
    assert matched.shape == expected_shape
Exemplo n.º 5
0
def test_shape_matching(arr, shape, expected):
    """Check that ``util.try_match_shape(arr, shape)`` produces an array equal to ``expected``."""
    reshaped = util.try_match_shape(arr, shape)
    assert np.array_equal(reshaped, expected)