def testCheckNumericsV2PositiveInfAndNaN(self):
  """CheckNumericsV2 reports the sign of an Inf that appears alongside a NaN."""
  expected_error = r"pass through test.*had \+Inf and NaN values"
  with self.session(graph=ops.Graph()):
    # Element-wise: 0/0 -> NaN, 1/0 -> +Inf.
    numerators = constant_op.constant([0.0, 1.0])
    denominators = constant_op.constant([0.0, 0.0])
    quotient = numerators / denominators
    with self.assertRaisesRegex(errors.InvalidArgumentError, expected_error):
      self.evaluate(
          array_ops.check_numerics_v2(quotient, message="pass through test"))
def testCheckNumericsV2OpNegativeAndPositiveInfAndNaN(self):
  """Both Inf signs are reported separately even when a NaN is also present."""
  expected_error = r"pass through test.*had -Inf, \+Inf, and NaN values"
  with self.session(graph=ops.Graph()):
    # Element-wise: -1/0 -> -Inf, 1/0 -> +Inf, 0/0 -> NaN.
    numerators = constant_op.constant([-1.0, 1.0, 0.0])
    denominators = constant_op.constant([0.0, 0.0, 0.0])
    quotient = numerators / denominators
    with self.assertRaisesRegex(errors.InvalidArgumentError, expected_error):
      self.evaluate(
          array_ops.check_numerics_v2(quotient, message="pass through test"))
def testCheckNumericsV2OpNegativeAndPositiveInf(self):
  """CheckNumericsV2 reports -Inf and +Inf as distinct categories."""
  expected_error = r"pass through test.*had -Inf and \+Inf values"
  with self.session(graph=ops.Graph()):
    # Element-wise: -1/0 -> -Inf, 1/0 -> +Inf; no NaN is produced.
    quotient = (constant_op.constant([-1.0, 1.0]) /
                constant_op.constant([0.0, 0.0]))
    with self.assertRaisesRegex(errors.InvalidArgumentError, expected_error):
      self.evaluate(
          array_ops.check_numerics_v2(quotient, message="pass through test"))
def testCheckNumericsV2PositiveInfAndNaN(self):
  """Test that CheckNumericsV2 op shows sign of inf when nan is present."""
  # NOTE(review): a test method with this exact name also appears above. If
  # both are defined in the same class, Python keeps only the later one and
  # the other silently never runs -- consider renaming or removing one.
  with self.session(graph=ops.Graph()):
    # Element-wise: 0/0 -> NaN, 1/0 -> +Inf.
    t1 = constant_op.constant([0.0, 1.0])
    t2 = constant_op.constant([0.0, 0.0])
    checked = array_ops.check_numerics_v2(t1 / t2, message="pass through test")
    # assertRaises fails with a clear message if no error is raised; the old
    # `caught = None` pattern instead crashed with AttributeError on
    # `caught.message`.
    with self.assertRaises(errors.InvalidArgumentError) as raised:
      self.evaluate(checked)
    self.assertIn("had +Inf and NaN values", raised.exception.message)
    self.assertIn("pass through test", raised.exception.message)
def testCheckNumericsV2OpNegativeAndPositiveInfAndNaN(self):
  """CheckNumericsV2 op distinguishes - & + infs when nan is present."""
  # NOTE(review): a test method with this exact name also appears above. If
  # both are defined in the same class, Python keeps only the later one and
  # the other silently never runs -- consider renaming or removing one.
  with self.session(graph=ops.Graph()):
    # Element-wise: -1/0 -> -Inf, 1/0 -> +Inf, 0/0 -> NaN.
    t1 = constant_op.constant([-1.0, 1.0, 0.0])
    t2 = constant_op.constant([0.0, 0.0, 0.0])
    checked = array_ops.check_numerics_v2(t1 / t2, message="pass through test")
    # assertRaises fails with a clear message if no error is raised; the old
    # `caught = None` pattern instead crashed with AttributeError on
    # `caught.message`.
    with self.assertRaises(errors.InvalidArgumentError) as raised:
      self.evaluate(checked)
    self.assertIn("had -Inf, +Inf, and NaN values", raised.exception.message)
    self.assertIn("pass through test", raised.exception.message)
def _CheckNumericsV2Grad(op, grad):
  """Gradient function for the CheckNumericsV2 op.

  The incoming gradient is itself passed through check_numerics_v2, so NaN or
  Inf values in the backpropagated gradient are reported. The report prefixes
  the forward op's own "message" attribute with a fixed notice.

  Args:
    op: The forward CheckNumericsV2 operation; its "message" attr is read.
    grad: The gradient flowing back into the op's output.

  Returns:
    The result of check_numerics_v2 applied to `grad`.
  """
  template = (
      "Not a number (NaN) or infinity (Inf) values detected in gradient. %s")
  gradient_message = template % op.get_attr("message")
  return array_ops.check_numerics_v2(grad, gradient_message)
def callback(self,
             op_type,
             inputs,
             attrs,
             outputs,
             op_name=None,
             graph=None):
  """Eager-function unified callback for checking numerics.

  Routes floating-point outputs of an op through check_numerics_v2, using an
  error message built by get_check_numerics_error_message. Ops listed in
  op_callbacks_common.OP_CALLBACK_SKIP_OPS or SAFE_OPS, and (op_type, slot)
  pairs in IGNORE_OP_OUTPUTS, are not checked.

  Args:
    op_type: Type of the op; converted to bytes with compat.as_bytes.
    inputs: Input tensors of the op.
    attrs: Op attributes (unused).
    outputs: Output tensors of the op.
    op_name: Name of the op (unused).
    graph: The graph the op is being created in; falsy under eager execution.

  Returns:
    In graph mode, a list with one entry per element of `outputs`: the
    instrumented tensor from self._get_output_tensor for checked outputs, or
    the original output otherwise. Returns None for skipped ops and under
    eager execution.
  """
  del attrs, op_name  # Unused
  op_type_bytes = compat.as_bytes(op_type)
  is_v1_graph_mode = not ops.executing_eagerly_outside_functions()
  if (op_type_bytes in op_callbacks_common.OP_CALLBACK_SKIP_OPS or
      op_type_bytes in SAFE_OPS):
    return None
  if graph:
    # Under graph mode. Insert check_numerics op.
    instrumented_outputs = []
    if is_v1_graph_mode:
      # For every input that has a registered debug tensor, add a control
      # input so that the debug tensor's op runs before this op's first
      # output is produced.
      for input_tensor in inputs:
        if input_tensor in self._placeholder_to_debug_tensor and outputs:
          outputs[0].op._add_control_input(  # pylint: disable=protected-access
              self._placeholder_to_debug_tensor[input_tensor].op)
    for slot, output in enumerate(outputs):
      if (output.dtype.is_floating and
          (op_type_bytes, slot) not in IGNORE_OP_OUTPUTS):
        checked_output = array_ops.check_numerics_v2(
            # In TF1 graph mode the raw output is checked directly. TF2 adds
            # automatic control dependencies to stateful async ops, which
            # allows check_numerics to run asynchronously. So in TF2 the
            # output is first reduced with _debug_summary, and the summary
            # is what gets checked.
            output if is_v1_graph_mode else _debug_summary(output),
            get_check_numerics_error_message(
                slot,
                len(outputs),
                op_type,
                output,
                inputs,
                graph=graph,
                traceback=output.op.traceback,
                # Apply the same message limits as the eager branch below.
                stack_height_limit=self._stack_height_limit,
                path_length_limit=self._path_length_limit))
        # Record which original tensor this check op's output corresponds to.
        _CHECK_NUMERICS_INPUT_LOOKUP[graph][checked_output.name] = output
        instrumented_outputs.append(
            self._get_output_tensor(op_type_bytes, output, checked_output,
                                    is_v1_graph_mode))
      else:
        instrumented_outputs.append(output)
    return instrumented_outputs
  else:
    if op_type_bytes == b"CheckNumericsV2":
      # TODO(b/140334369): Remove this special casing logic once op_callback
      # automatically prevents infinite recursion in eager mode.
      return None
    # Under eager mode. Eagerly execute check_numerics op.
    for slot, output in enumerate(outputs):
      if (output.dtype.is_floating and
          (op_type_bytes, slot) not in IGNORE_OP_OUTPUTS):
        array_ops.check_numerics_v2(
            output,
            get_check_numerics_error_message(
                slot,
                len(outputs),
                op_type,
                output,
                inputs,
                stack_height_limit=self._stack_height_limit,
                path_length_limit=self._path_length_limit))