def compare_traces(ref_trace: Trace,
                   tar_trace: Trace) -> Tuple[bool, Sequence[str]]:
  """Compares a reference trace against a target trace, call by call.

  Args:
    ref_trace: the trace recorded on the reference backend.
    tar_trace: the trace recorded on the target backend.

  Returns:
    A tuple `(traces_match, error_messages)`. `traces_match` is True only if
    the inputs and outputs of every call matched within the tolerances
    returned by the reference call's `get_tolerances()`. `error_messages`
    collects the message returned by `tf_utils.check_same` for every
    mismatching input or output.

  Raises:
    ValueError: if the two traces do not contain the same sequence of
      `(method, rtol, atol)` calls.
  """
  traces_match = True
  error_messages = []

  # Check that all method invocations match.
  ref_methods = [(call.method, call.rtol, call.atol) for call in ref_trace]
  tar_methods = [(call.method, call.rtol, call.atol) for call in tar_trace]
  if ref_methods != tar_methods:
    # Raise a ValueError instead of returning False since this is an
    # unexpected error.
    raise ValueError(
        "The reference and target traces have different call structures:\n"
        f"Reference: {ref_methods}\nTarget: {tar_methods}")

  for ref_call, tar_call in zip(ref_trace, tar_trace):
    logging.info("Comparing calls to '%s'", ref_call.method)
    # Only the reference call's tolerances are used. The structural check
    # above already ensured that both calls declare the same rtol/atol.
    rtol, atol = ref_call.get_tolerances()

    inputs_match, error_message = tf_utils.check_same(
        ref_call.inputs, tar_call.inputs, rtol, atol)
    if not inputs_match:
      error_messages.append(error_message)
      logging.error("Inputs did not match.")

    outputs_match, error_message = tf_utils.check_same(
        ref_call.outputs, tar_call.outputs, rtol, atol)
    if not outputs_match:
      error_messages.append(error_message)
      logging.error("Outputs did not match.")

    calls_match = inputs_match and outputs_match
    if not calls_match:
      logging.error("Comparison between '%s' and '%s' failed on method '%s'",
                    ref_trace.backend_id, tar_trace.backend_id,
                    ref_call.method)
      logging.error("Reference call '%s':\n%s", ref_trace.backend_id,
                    ref_call)
      logging.error("Target call '%s':\n%s", tar_trace.backend_id, tar_call)

    traces_match = traces_match and calls_match
  return traces_match, error_messages
def test_recursive_check_same(self, array_c, array_d, array_e, tar_same):
  # Verifies that tf_utils.check_same descends through nested dicts and
  # lists. Reference and target share an identical layout; only the three
  # leaf arrays differ, and `tar_same` is the expected verdict.
  def nest(c_leaf, d_leaf, e_leaf):
    # Layout shared by both sides:
    #   {'a': 1, 'b': [{'c': c_leaf}, {'d': d_leaf}, {'e': e_leaf}]}
    return {'a': 1, 'b': [{'c': c_leaf}, {'d': d_leaf}, {'e': e_leaf}]}

  ref = nest(np.array([0, 1, 2]),
             np.array(['0', '1', '2']),
             np.array([0.0, 0.1, 0.2]))
  tar = nest(array_c, array_d, array_e)

  same, _ = tf_utils.check_same(ref, tar, rtol=1e-6, atol=1e-6)
  self.assertEqual(tar_same, same)