def test_failure_at_PrepareCompositeFunctionsPass(self):
  if context.is_tfrt_enabled():
    self.skipTest('This test crashed with TFRT.')

  class NgramsLayer(tf.keras.layers.Layer):

    def call(self, input_tensor, **kwargs):
      return mock_ngrams(input_tensor, width=2, axis=-1, string_separator=' ')

  # Registers a fake WhitespaceTokenizeWithOffsets so that the TFText fusing
  # logic is enabled on the MLIR side.
  custom_opdefs_str = (
      'name: \'WhitespaceTokenizeWithOffsets\' input_arg: {name: \'Input1\' '
      'type: DT_FLOAT} input_arg: {name: \'Input2\' type: DT_FLOAT} '
      'output_arg: {name: \'Output\' type: DT_FLOAT}')
  register_custom_opdefs([custom_opdefs_str])

  model = tf.keras.models.Sequential([NgramsLayer()])
  model.predict(tf.constant(['test']))
  converter = tf.lite.TFLiteConverter.from_keras_model(model)
  converter.allow_custom_ops = True
  self.convert_and_check_location_info(
      converter, converter_error_data_pb2.ConverterErrorData.UNKNOWNLOC)
  exported_error = metrics._gauge_conversion_errors.get_cell(
      'CONVERT_TF_TO_TFLITE_MODEL', 'PrepareCompositeFunctionsPass', '',
      'UNKNOWN').value()
  self.assertEqual(
      exported_error,
      "\'width\' attribute is not set or not an integer\n")

def dtensor_shutdown_tpu_system():
  """Shuts down the TPU system."""

  @def_function.function
  def _shutdown_tpu_system():
    return gen_dtensor_ops.shutdown_tpu_system()

  success = _shutdown_tpu_system() if context.is_tfrt_enabled() else True
  if success:
    logging.info("TPU system shut down.")
  else:
    logging.warning("TPU system failed to shut down.")

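# A brief usage sketch, purely illustrative: a DTensor program running on TPU
# could call the helper above during teardown. The atexit wiring below is an
# assumption for illustration, not something shown in the snippet.
import atexit

# Assumed wiring: shut the TPU system down when the process exits so the next
# job starts from a clean state; the helper above is a no-op outside TFRT.
atexit.register(dtensor_shutdown_tpu_system)
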
def __init__(self):
  # TODO(b/153054118): Add tf.RandomUniform
  if not context.is_tfrt_enabled():
    # used for multiply benchmarks
    self._m_2 = random_ops.random_uniform([2])

    # used for matmul benchmarks
    self._m_2_by_2 = random_ops.random_uniform((2, 2))
    self._m_100_by_784 = random_ops.random_uniform((100, 784))

  self._num_iters_2_by_2 = 30000
  self._num_iters_100_by_784 = 30000

def testInvalidOutputTypeMatmul(self):
  for dtype in [dtypes.int8, dtypes.bfloat16]:
    a = constant_op.constant(np.arange(1, 13), shape=[2, 2, 3], dtype=dtype)
    b = constant_op.constant(
        np.arange(13, 25), shape=[2, 3, 2], dtype=dtypes.int8)
    if context.executing_eagerly():
      if context.is_tfrt_enabled():
        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    "NodeDef expected inputs"):
          math_ops.matmul(a, b, output_type=dtypes.float32)
      else:
        with self.assertRaisesRegex(errors.NotFoundError,
                                    "Could not find device for node:"):
          math_ops.matmul(a, b, output_type=dtypes.float32)
    else:
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  "No OpKernel was registered to support Op"):
        self.evaluate(math_ops.matmul(a, b, output_type=dtypes.float32))

def _get_benchmark_name(self):
  """Mostly copied from benchmark.py _get_name()."""
  stack = tf_inspect.stack()
  name = None
  for frame in stack[::-1]:
    f_locals = frame[0].f_locals
    f_self = f_locals.get("self", None)
    if isinstance(f_self, test.Benchmark):
      name = frame[3]  # Get the method name.
      # This is a hack to get around the fact that some methods might have a
      # disable_tfrt decorator around them. In that case a function called
      # 'decorated' wraps the real called function underneath and so we
      # peek one deeper into the stack to get the real name.
      if name == "decorated":
        continue
      else:
        break
  if name is None:
    raise ValueError("Unable to determine calling Benchmark function.")
  if context.is_tfrt_enabled():
    name = name + "_tfrt"
  return name

def _configure_tpu_runtime():
  was_enabled = context.is_tfrt_enabled()
  if ("tpu_use_tfrt" in flags.FLAGS and flags.FLAGS["tpu_use_tfrt"].value):
    tfrt_utils.set_tfrt_enabled(True)
  if not was_enabled:
    context._reset_context()  # pylint:disable=protected-access

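# A hedged sketch only: the tpu_use_tfrt flag checked above is normally
# defined by TensorFlow itself; the standalone flag definition and app wiring
# below are assumptions made for illustration.
from absl import app
from absl import flags

flags.DEFINE_bool("tpu_use_tfrt", False, "Whether to use the TFRT TPU runtime.")


def main(_):
  # Enable TFRT (resetting the eager context if needed) before any TPU work.
  _configure_tpu_runtime()


if __name__ == "__main__":
  app.run(main)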