예제 #1
0
    def test_failure_at_PrepareCompositeFunctionsPass(self):
        """Converting an ngrams layer reports a PrepareCompositeFunctionsPass error.

        The converted model is checked for UNKNOWNLOC location info. The error
        text exported through the `_gauge_conversion_errors` metric for
        `PrepareCompositeFunctionsPass` must match the expected message.
        """
        if context.is_tfrt_enabled():
            self.skipTest('This test crashed with TFRT.')

        class NgramsLayer(tf.keras.layers.Layer):
            # Forwards its input to `mock_ngrams` with fixed bigram arguments.
            def call(self, input_tensor, **kwargs):
                return mock_ngrams(
                    input_tensor, width=2, axis=-1, string_separator=' ')

        # Register a fake WhitespaceTokenizeWithOffsets op so that the TF.Text
        # fusing logic is enabled on the MLIR side.
        fake_tokenizer_opdef = (
            "name: 'WhitespaceTokenizeWithOffsets' input_arg: {name: 'Input1' "
            "type: DT_FLOAT} input_arg: {name: 'Input2' type: DT_FLOAT} "
            "output_arg: {name: 'Output' type: DT_FLOAT}")
        register_custom_opdefs([fake_tokenizer_opdef])

        keras_model = tf.keras.models.Sequential([NgramsLayer()])
        # Run the model once before converting it (presumably so the
        # Sequential model gets built from a concrete input).
        keras_model.predict(tf.constant(['test']))

        lite_converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
        lite_converter.allow_custom_ops = True
        self.convert_and_check_location_info(
            lite_converter,
            converter_error_data_pb2.ConverterErrorData.UNKNOWNLOC)

        gauge_cell = metrics._gauge_conversion_errors.get_cell(
            'CONVERT_TF_TO_TFLITE_MODEL', 'PrepareCompositeFunctionsPass', '',
            'UNKNOWN')
        self.assertEqual(
            gauge_cell.value(),
            "'width' attribute is not set or not an integer\n")
예제 #2
0
def dtensor_shutdown_tpu_system():
    """Shut down the TPU system and log the outcome.

    The `shutdown_tpu_system` op is only run when TFRT is enabled. Otherwise
    nothing is executed and the shutdown is treated as successful.
    """

    @def_function.function
    def _shutdown_tpu_system():
        return gen_dtensor_ops.shutdown_tpu_system()

    shut_down_ok = True
    if context.is_tfrt_enabled():
        # NOTE(review): the op result is used as a truth value below;
        # presumably a scalar boolean tensor -- confirm.
        shut_down_ok = _shutdown_tpu_system()

    if not shut_down_ok:
        logging.warning("TPU system fails to shut down.")
        return
    logging.info("TPU system shut down.")
예제 #3
0
    def __init__(self):
        """Set iteration counts and, when TFRT is off, random input tensors."""
        self._num_iters_2_by_2 = 30000
        self._num_iters_100_by_784 = 30000

        # TODO(b/153054118): Add tf.RandomUniform
        # Under TFRT the random input tensors are not created (presumably
        # pending RandomUniform support there -- see the TODO above).
        if context.is_tfrt_enabled():
            return

        # Input for the multiply benchmarks.
        self._m_2 = random_ops.random_uniform([2])
        # Inputs for the matmul benchmarks.
        self._m_2_by_2 = random_ops.random_uniform((2, 2))
        self._m_100_by_784 = random_ops.random_uniform((100, 784))
예제 #4
0
 def testInvalidOutputTypeMatmul(self):
   """Matmul with float32 output_type fails for int8/bfloat16 inputs.

   The expected error class and message depend on whether execution is
   eager (and, if so, whether TFRT is enabled) or graph-based.
   """
   for in_dtype in (dtypes.int8, dtypes.bfloat16):
     lhs = constant_op.constant(
         np.arange(1, 13), shape=[2, 2, 3], dtype=in_dtype)
     rhs = constant_op.constant(
         np.arange(13, 25), shape=[2, 3, 2], dtype=dtypes.int8)

     if not context.executing_eagerly():
       # Graph mode: the error is expected while building or evaluating
       # the op.
       with self.assertRaisesRegex(errors.InvalidArgumentError,
                                   "No OpKernel was registered to support Op"):
         self.evaluate(math_ops.matmul(lhs, rhs, output_type=dtypes.float32))
       continue

     if context.is_tfrt_enabled():
       error_cls, error_regex = (errors.InvalidArgumentError,
                                 "NodeDef expected inputs")
     else:
       error_cls, error_regex = (errors.NotFoundError,
                                 "Could not find device for node:")
     with self.assertRaisesRegex(error_cls, error_regex):
       math_ops.matmul(lhs, rhs, output_type=dtypes.float32)
예제 #5
0
 def _get_benchmark_name(self):
   """Return the name of the calling `test.Benchmark` method.

   Mostly copied from benchmark.py _get_name(). Stack frames are scanned from
   the outermost caller inward. The name comes from the first frame whose
   `self` is a `test.Benchmark`, except that frames named "decorated" are
   passed over so the wrapped method beneath them is used. "_tfrt" is
   appended when TFRT is enabled.

   Raises:
     ValueError: if no frame on the stack has a `test.Benchmark` as `self`.
   """
   method_name = None
   for frame_record in reversed(tf_inspect.stack()):
     owner = frame_record[0].f_locals.get("self")
     if not isinstance(owner, test.Benchmark):
       continue
     method_name = frame_record[3]  # Name of the function in that frame.
     # Some methods are wrapped by a disable_tfrt decorator whose inner
     # function is called 'decorated'; keep walking inward to reach the
     # real method name in that case.
     if method_name != "decorated":
       break

   if method_name is None:
     raise ValueError("Unable to determine calling Benchmark function.")
   if context.is_tfrt_enabled():
     method_name += "_tfrt"
   return method_name
예제 #6
0
def _configure_tpu_runtime():
    """Optionally enable TFRT for TPU runs, then reset the eager context.

    TFRT is enabled when the `tpu_use_tfrt` flag is defined and set. The
    context is reset whenever TFRT was not already enabled on entry.
    """
    tfrt_already_on = context.is_tfrt_enabled()

    wants_tfrt = ("tpu_use_tfrt" in flags.FLAGS
                  and flags.FLAGS["tpu_use_tfrt"].value)
    if wants_tfrt:
        tfrt_utils.set_tfrt_enabled(True)

    # NOTE(review): the reset below also happens when TFRT stays disabled
    # (flag unset or false); confirm whether it should only follow an actual
    # switch to TFRT.
    if tfrt_already_on:
        return
    context._reset_context()  # pylint:disable=protected-access