Example #1
0
def compare_traces(ref_trace: Trace,
                   tar_trace: Trace) -> Tuple[bool, Sequence[str]]:
    """Compare two traces call-by-call and report whether they match.

    Both traces must have the same call structure, i.e. the same sequence of
    ``(method, rtol, atol)`` triples. Each pair of corresponding calls then has
    its inputs and its outputs compared with ``tf_utils.check_same``. The
    tolerances for that comparison come from the reference call's
    ``get_tolerances()``.

    Both traces are iterated twice (once for the structure check, once for the
    comparison), so they must be re-iterable collections of calls, not
    one-shot iterators.

    Args:
      ref_trace: the reference trace.
      tar_trace: the target trace to check against ``ref_trace``.

    Returns:
      A ``(traces_match, error_messages)`` tuple. ``traces_match`` is truthy
      only if the inputs and outputs of every call matched. ``error_messages``
      holds the message from each failing ``check_same`` comparison, in call
      order (inputs before outputs for a given call).

    Raises:
      ValueError: if the traces have different call structures. This is raised
        rather than returned because it signals an unexpected error, not a
        numerical mismatch.
    """
    traces_match = True
    error_messages = []

    # Check that all method invocations match.
    ref_methods = [(call.method, call.rtol, call.atol) for call in ref_trace]
    tar_methods = [(call.method, call.rtol, call.atol) for call in tar_trace]
    if ref_methods != tar_methods:
        # Raise a ValueError instead of returning False since this is an
        # unexpected error.
        raise ValueError(
            "The reference and target traces have different call structures:\n"
            f"Reference: {ref_methods}\nTarget:    {tar_methods}")

    # The structure lists are equal, hence equally long, so zip pairs every
    # call and never silently drops trailing calls of either trace.
    for ref_call, tar_call in zip(ref_trace, tar_trace):
        logging.info("Comparing calls to '%s'", ref_call.method)
        # Tolerances are always taken from the reference call.
        rtol, atol = ref_call.get_tolerances()

        inputs_match, error_message = tf_utils.check_same(
            ref_call.inputs, tar_call.inputs, rtol, atol)
        if not inputs_match:
            error_messages.append(error_message)
            logging.error("Inputs did not match.")
        outputs_match, error_message = tf_utils.check_same(
            ref_call.outputs, tar_call.outputs, rtol, atol)
        if not outputs_match:
            error_messages.append(error_message)
            logging.error("Outputs did not match.")
        calls_match = inputs_match and outputs_match

        if not calls_match:
            logging.error(
                "Comparison between '%s' and '%s' failed on method '%s'",
                ref_trace.backend_id, tar_trace.backend_id, ref_call.method)
            logging.error("Reference call '%s':\n%s", ref_trace.backend_id,
                          ref_call)
            logging.error("Target call '%s':\n%s", tar_trace.backend_id,
                          tar_call)

        # Keep comparing remaining calls so every mismatch gets reported.
        traces_match = traces_match and calls_match
    return traces_match, error_messages
Example #2
0
  def test_recursive_check_same(self, array_c, array_d, array_e, tar_same):
    """Checks tf_utils.check_same on a nested dict/list structure of arrays.

    The reference structure holds fixed integer, string and float arrays under
    the keys 'c', 'd' and 'e'. The target structure has the same nesting, with
    the candidate arrays ``array_c``, ``array_d`` and ``array_e`` in their
    place. The test asserts that the boolean verdict of check_same (with
    rtol=atol=1e-6) equals ``tar_same``.

    NOTE(review): the arguments are presumably supplied by a parameterized
    test decorator outside this view; confirm.
    """
    # yapf: disable
    reference = {
        'a': 1,
        'b': [{'c': np.array([0, 1, 2])},
              {'d': np.array(['0', '1', '2'])},
              {'e': np.array([0.0, 0.1, 0.2])}],
    }
    target = {
        'a': 1,
        'b': [{'c': array_c}, {'d': array_d}, {'e': array_e}],
    }
    # yapf: enable
    # Only the verdict matters here; the error message is discarded.
    is_same, _unused_message = tf_utils.check_same(
        reference, target, rtol=1e-6, atol=1e-6)
    self.assertEqual(tar_same, is_same)