    def test_strategy(self, strategy):
        if (
            backend.is_tpu_strategy(strategy)
            and not tf_test_utils.is_mlir_bridge_enabled()
        ):
            self.skipTest("TPU tests require MLIR bridge")

        input_array = np.array([[1, 2, 3, 1], [0, 3, 1, 0]])
        inp_dataset = tf.data.Dataset.from_tensor_slices(input_array)
        inp_dataset = batch_wrapper(inp_dataset, 2, strategy)

        # pyformat: disable
        expected_output = [[0, 1, 1, 1, 0, 0], [1, 1, 0, 1, 0, 0]]
        # pyformat: enable
        num_tokens = 6
        tf.config.set_soft_device_placement(True)

        with strategy.scope():
            input_data = keras.Input(shape=(4,), dtype=tf.int32)
            layer = category_encoding.CategoryEncoding(
                num_tokens=num_tokens, output_mode=category_encoding.MULTI_HOT
            )
            int_data = layer(input_data)
            model = keras.Model(inputs=input_data, outputs=int_data)
        output_dataset = model.predict(inp_dataset)
        self.assertAllEqual(expected_output, output_dataset)
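The batch_wrapper helper above is defined elsewhere in the test file. A minimal sketch, assuming its only job is batching (TPU strategies require statically shaped batches, so the partial final batch is dropped there):

def batch_wrapper(dataset, batch_size, strategy):
    # Hypothetical reconstruction of the helper used above: drop the partial
    # final batch under TPU strategies, which need fully defined shapes.
    if backend.is_tpu_strategy(strategy):
        return dataset.batch(batch_size, drop_remainder=True)
    return dataset.batch(batch_size)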
  def test_strategy_with_file(self, strategy):
    if (backend.is_tpu_strategy(strategy) and
        not tf_test_utils.is_mlir_bridge_enabled()):
      self.skipTest("TPU tests require MLIR bridge")

    vocab_data = ["earth", "wind", "and", "fire"]
    vocab_file = self._write_to_temp_file("temp", vocab_data)

    input_array = np.array([["earth", "wind", "and", "fire"],
                            ["fire", "and", "earth", "michigan"]])
    input_dataset = tf.data.Dataset.from_tensor_slices(input_array).batch(
        2, drop_remainder=True)
    expected_output = [[2, 3, 4, 5], [5, 4, 2, 1]]

    tf.config.set_soft_device_placement(True)

    with strategy.scope():
      input_data = keras.Input(shape=(None,), dtype=tf.string)
      layer = index_lookup.IndexLookup(
          max_tokens=None,
          num_oov_indices=1,
          mask_token="",
          oov_token="[OOV]",
          vocabulary_dtype=tf.string,
          vocabulary=vocab_file)
      int_data = layer(input_data)
      model = keras.Model(inputs=input_data, outputs=int_data)
    model.compile(loss="mse")
    output_dataset = model.predict(input_dataset)
    self.assertAllEqual(expected_output, output_dataset)
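The _write_to_temp_file helper comes from the test base class. A plausible sketch (signature assumed) writes one vocabulary token per line, the file format IndexLookup expects:

def _write_to_temp_file(self, file_name, vocab_list):
  # Hypothetical sketch: one token per line in a temp file.
  vocab_path = os.path.join(self.get_temp_dir(), file_name + ".txt")
  with tf.io.gfile.GFile(vocab_path, "w") as writer:
    for vocab in vocab_list:
      writer.write(vocab + "\n")
  return vocab_path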
Example #3
    def test_distribution_strategy_output(self, strategy):
        if (backend.is_tpu_strategy(strategy)
                and not tf_test_utils.is_mlir_bridge_enabled()):
            self.skipTest("TPU tests require MLIR bridge")

        vocab_data = ["earth", "wind", "and", "fire"]
        input_array = np.array([["earth", "wind", "and", "fire"],
                                ["fire", "and", "earth", "michigan"]])
        input_dataset = tf.data.Dataset.from_tensor_slices(input_array).batch(
            2, drop_remainder=True)

        expected_output = [[2, 3, 4, 5], [5, 4, 2, 1]]

        tf.config.set_soft_device_placement(True)

        with strategy.scope():
            input_data = keras.Input(shape=(None, ), dtype=tf.string)
            layer = text_vectorization.TextVectorization(
                max_tokens=None,
                standardize=None,
                split=None,
                output_mode=text_vectorization.INT,
                vocabulary=vocab_data)
            int_data = layer(input_data)
            model = keras.Model(inputs=input_data, outputs=int_data)

        output_dataset = model.predict(input_dataset)
        self.assertAllEqual(expected_output, output_dataset)
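In INT output mode, index 0 is reserved for padding and index 1 for the single OOV token, so the vocabulary ["earth", "wind", "and", "fire"] maps to indices 2 through 5 and the out-of-vocabulary "michigan" maps to 1, matching expected_output. A quick standalone check (eager, no strategy needed):

layer = text_vectorization.TextVectorization(
    max_tokens=None, standardize=None, split=None,
    output_mode=text_vectorization.INT, vocabulary=vocab_data)
print(layer(np.array([["earth", "michigan"]])))  # [[2 1]]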
  def testV2SummaryWithKerasSubclassedModel(self):
    # Histogram summaries require the MLIR bridge; see b/178826597#comment107.
    # TODO(https://github.com/tensorflow/tensorboard/issues/2885): remove this
    #   if histogram summaries are supported fully on non-MLIR bridge or
    #   non-MLIR bridge is no longer run.
    enable_histograms = test_util.is_mlir_bridge_enabled()
    strategy = get_tpu_strategy()
    with strategy.scope():
      model = CustomModel(enable_histograms=enable_histograms)
      model.compile('sgd', 'mse')

      dataset = distribute_strategy_test.get_dataset(strategy)
      tensorboard_callback = callbacks.TensorBoard(
          self.summary_dir, update_freq=2)
      model.fit(
          dataset,
          steps_per_epoch=10,
          epochs=1,
          callbacks=[tensorboard_callback])

      event_files = tf.io.gfile.glob(
          os.path.join(self.summary_dir, 'train', 'event*'))
      # Since a total of 10 steps are run and summary ops are invoked every
      # 2 batches, we should see a total of 5 event logs for each summary.
      expected_event_counts = {
          ('custom_model/layer_for_scalar_summary/'
           'custom_scalar_summary_v2'):
              5,
          ('custom_model/layer_for_histogram_summary/'
           'custom_histogram_summary_v2'):
              5 if enable_histograms else 0,
      }
      self.validate_recorded_sumary_file(event_files, expected_event_counts)
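CustomModel and get_tpu_strategy are defined elsewhere in this test module. Judging by the names asserted in expected_event_counts, a rough sketch of the model could look like the following (scope and summary names taken from the assertions; everything else is an assumption):

class CustomModel(tf.keras.Model):
  # Hypothetical reconstruction: each block emits one v2 summary per call, and
  # the TensorBoard callback with update_freq=2 records them every 2 batches.
  def __init__(self, enable_histograms=True, name="custom_model"):
    super().__init__(name=name)
    self.enable_histograms = enable_histograms

  def call(self, x):
    with tf.name_scope("layer_for_scalar_summary"):
      tf.summary.scalar("custom_scalar_summary_v2", tf.reduce_sum(x))
    if self.enable_histograms:
      with tf.name_scope("layer_for_histogram_summary"):
        tf.summary.histogram("custom_histogram_summary_v2", x)
    return x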
Example #5
    def testV2SummaryWithKerasSequentialModel(self):
        # Histogram summaries require the MLIR bridge; see b/178826597#comment107.
        # TODO(https://github.com/tensorflow/tensorboard/issues/2885): remove this
        #   if histogram summaries are supported fully on non-MLIR bridge or
        #   non-MLIR bridge is no longer run.
        enable_histograms = tf_test_utils.is_mlir_bridge_enabled()
        strategy = get_tpu_strategy()

        with strategy.scope():
            model = mnist_model((28, 28, 3),
                                enable_histograms=enable_histograms)
            model.compile("sgd", "mse")

            dataset = get_image_dataset()
            tensorboard_callback = callbacks.TensorBoard(self.summary_dir,
                                                         update_freq=2)
            model.fit(
                dataset,
                steps_per_epoch=10,
                epochs=1,
                callbacks=[tensorboard_callback],
            )

            event_files = tf.io.gfile.glob(
                os.path.join(self.summary_dir, "train", "event*"))
            # Since a total of 10 steps are run and summary ops are invoked
            # every 2 batches, we should see a total of 5 event logs for each
            # summary.
            expected_event_counts = {
                "sequential/layer_for_histogram_summary/custom_histogram_summary_v2":
                5 if enable_histograms else 0,
                "sequential/layer_for_image_summary/custom_image_summary_v2":
                5,
            }
            self.validate_recorded_sumary_file(event_files,
                                               expected_event_counts)
  def test_spmd_with_summary(self):
    if test_util.is_mlir_bridge_enabled():
      self.skipTest("TODO(b/232580663): fix MLIR bridge")
    original_device_placement = config.get_soft_device_placement()
    config.set_soft_device_placement(True)

    strategy, _ = get_tpu_strategy(enable_spmd=True)
    summary_dir = self.get_temp_dir()
    writer = summary_ops.create_file_writer_v2(summary_dir)

    with strategy.scope():
      step = variables.Variable(0, dtype=dtypes.int64)

    @def_function.function
    def run():
      with writer.as_default():
        summary_ops.scalar("result", step * 2, step=step)
        step.assign_add(1)

    for _ in range(10):
      strategy.run(run, args=())

    for val in step.values:
      for var in val.variables:
        self.assertAllEqual(10, var)

    config.set_soft_device_placement(original_device_placement)
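To inspect what the writer actually recorded, the event files can be walked with the public summary iterator. A sketch, not part of the original test and using the tf.* API rather than the internal modules above:

def count_summary_tags(summary_dir):
  # Count how many times each tag was recorded across all event files.
  counts = {}
  for path in tf.io.gfile.glob(os.path.join(summary_dir, "event*")):
    for event in tf.compat.v1.train.summary_iterator(path):
      for value in event.summary.value:
        counts[value.tag] = counts.get(value.tag, 0) + 1
  return counts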
  def test_partitioned_model_checkpointing(self):
    if test_util.is_mlir_bridge_enabled():
      self.skipTest("TODO(b/238811067): fix MLIR bridge")

    class PartitionedModel(module.Module):

      def __init__(self, v, w):
        super(PartitionedModel, self).__init__()

        assert distribution_strategy_context.has_strategy()
        strategy = distribution_strategy_context.get_strategy()

        with strategy.extended.experimental_logical_device(0):
          self.v = variables.Variable(v)
        with strategy.extended.experimental_logical_device(1):
          self.w = variables.Variable(w)

      def __call__(self, x):
        replica_ctx = distribution_strategy_context.get_replica_context()
        with replica_ctx.experimental_logical_device(0):
          y = self.v * x
        with replica_ctx.experimental_logical_device(1):
          z = self.w * y
        return z

      def change_weights_op(self, v_new, w_new):
        return control_flow_ops.group(
            [self.v.assign(v_new), self.w.assign(w_new)])

    strategy, num_replicas = get_tpu_strategy()
    with strategy.scope():
      model = PartitionedModel(2., 3.)

    checkpoint_dir = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = util.Checkpoint(model=model)

    with self.cached_session() as sess:
      self.evaluate(variables.global_variables_initializer())
      checkpoint.save(file_prefix=checkpoint_prefix)

      self.evaluate(model.change_weights_op(1., 4.))
      result = strategy.run(def_function.function(model), args=(5.0,))
      self.assertEqual(20. * num_replicas,
                       self.evaluate(strategy.reduce("SUM", result, axis=None)))

      status = checkpoint.restore(
          checkpoint_management.latest_checkpoint(checkpoint_dir))
      status.run_restore_ops(sess)  # must run restore op in non-eager mode.
      status.assert_consumed()
      status.assert_existing_objects_matched()
      result = strategy.run(def_function.function(model), args=(5.0,))
      self.assertEqual(30. * num_replicas,
                       self.evaluate(strategy.reduce("SUM", result, axis=None)))
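The get_tpu_strategy helper here has to return the strategy together with its replica count and give each replica two logical devices, or the experimental_logical_device(0|1) calls above would fail. A rough sketch with the public TPU APIs (the computation_shape and resolver arguments are assumptions):

def get_tpu_strategy():
  # Hypothetical sketch: partition each replica across two logical TPU cores.
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
  tf.config.experimental_connect_to_cluster(resolver)
  topology = tf.tpu.experimental.initialize_tpu_system(resolver)
  num_replicas = resolver.get_tpu_system_metadata().num_cores // 2
  device_assignment = tf.tpu.experimental.DeviceAssignment.build(
      topology, computation_shape=[1, 1, 1, 2], num_replicas=num_replicas)
  strategy = tf.distribute.TPUStrategy(
      resolver, experimental_device_assignment=device_assignment)
  return strategy, num_replicas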
Example #8
  def testUnsupportedOps(self):
    with ops.device('device:{}:0'.format(self.device)):

      def fn(x):
        return string_ops.string_length(
            string_ops.string_format('{}', x))

      xla_func = def_function.function(fn, jit_compile=True)

      with self.assertRaisesRegex(
          errors.InvalidArgumentError, 'legalization failed'
          if test_util.is_mlir_bridge_enabled() else 'unsupported operations'):
        xla_func(constant_op.constant([3.1, 3.2]))
Example #9
    def testMethodCompilationUnsupportedFunc(self):
        with ops.device('device:{}:0'.format(self.device)):

            class C(object):
                @def_function.function(jit_compile=True)
                def f1(self, x):
                    return string_ops.string_length(
                        string_ops.string_format('{}', x))

            inputs = constant_op.constant([1, 2, 2, 3, 3])
            c = C()
            with self.assertRaisesRegex(
                    errors.InvalidArgumentError, 'legalization failed'
                    if test_util.is_mlir_bridge_enabled() else
                    'unsupported operations'):
                c.f1(inputs)
    def _testReduction(self,
                       tf_reduce_fn,
                       np_reduce_fn,
                       dtype,
                       test_inputs,
                       index_dtype,
                       rtol=1e-4,
                       atol=1e-4):
        """Tests that the output of 'tf_reduce_fn' matches numpy's output."""

        for test_input in test_inputs:
            with self.session() as sess:
                with self.test_scope():
                    a = array_ops.placeholder(dtype)
                    index = array_ops.placeholder(index_dtype)
                    out = tf_reduce_fn(a, index)
                result = sess.run(out, {a: test_input, index: [0]})
                self.assertAllClose(result,
                                    np_reduce_fn(test_input, axis=0),
                                    rtol=rtol,
                                    atol=atol)

                result = sess.run(out, {a: test_input, index: [1]})
                self.assertAllClose(result,
                                    np_reduce_fn(test_input, axis=1),
                                    rtol=rtol,
                                    atol=atol)

                result = sess.run(out, {a: test_input, index: [-1]})
                self.assertAllClose(result,
                                    np_reduce_fn(test_input, axis=1),
                                    rtol=rtol,
                                    atol=atol)

                # MLIR bridge doesn't return the same error so it can't be matched
                # directly.
                if not test_util.is_mlir_bridge_enabled():
                    with self.assertRaisesWithPredicateMatch(
                            errors_impl.InvalidArgumentError,
                            'Invalid reduction dim'):
                        sess.run(out, {a: test_input, index: [-33]})

                    with self.assertRaisesWithPredicateMatch(
                            errors_impl.InvalidArgumentError,
                            'Invalid reduction dim'):
                        sess.run(out, {a: test_input, index: [2]})
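Test methods in the same class would drive _testReduction like this (a hypothetical call; math_ops is assumed to be imported by the full test file, and the input values are made up):

self._testReduction(math_ops.reduce_sum, np.sum, np.float32,
                    [np.arange(6, dtype=np.float32).reshape(2, 3)],
                    np.int32)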
  def testNestedCallUnsupportedOps(self):
    with ops.device('device:{}:0'.format(self.device)):

      def fn(x):
        return array_ops.unique(x).y

      xla_func = def_function.function(fn, jit_compile=True)

      def fn2(x):
        return xla_func(x)

      func = def_function.function(fn2, jit_compile=False)
      inputs = constant_op.constant([1, 2, 2, 3, 3])
      with self.assertRaisesRegex(
          errors.InvalidArgumentError, 'legalization failed'
          if test_util.is_mlir_bridge_enabled() else 'not compilable'):
        func(inputs)
Example #12
  def testUnsupportedOps(self):
    if 'tpu' in self.device.lower():
      self.skipTest('XLA TPU supports tf.unique')

    with ops.device('device:{}:0'.format(self.device)):

      def fn(x):
        return array_ops.unique(x).y  # Unique is not supported by XLA

      func = def_function.function(fn, jit_compile=False)
      xla_func = def_function.function(fn, jit_compile=True)

      inputs = constant_op.constant([1, 2, 2, 3, 3])
      self.assertAllClose([1, 2, 3], func(inputs))
      with self.assertRaisesRegex(
          errors.InvalidArgumentError, 'legalization failed'
          if test_util.is_mlir_bridge_enabled() else 'unsupported operations'):
        xla_func(inputs)
    def testMethodCompilationUnsupportedFunc(self):
        if 'tpu' in self.device.lower():
            self.skipTest('XLA TPU supports tf.unique')

        with ops.device('device:{}:0'.format(self.device)):

            class C(object):
                @def_function.function(jit_compile=True)
                def f1(self, x):
                    return array_ops.unique(x).y

            inputs = constant_op.constant([1, 2, 2, 3, 3])
            c = C()
            with self.assertRaisesRegex(
                    errors.InvalidArgumentError, 'legalization failed'
                    if test_util.is_mlir_bridge_enabled() else
                    'unsupported operations'):
                c.f1(inputs)
Example #14
    def test_strategy(self, strategy):
        if (backend.is_tpu_strategy(strategy)
                and not tf_test_utils.is_mlir_bridge_enabled()):
            self.skipTest("TPU tests require MLIR bridge")

        input_data = np.asarray([["omar"], ["stringer"], ["marlo"], ["wire"]])
        input_dataset = tf.data.Dataset.from_tensor_slices(input_data).batch(
            2, drop_remainder=True)
        expected_output = [[0], [0], [1], [0]]

        tf.config.set_soft_device_placement(True)

        with strategy.scope():
            input_data = keras.Input(shape=(None, ), dtype=tf.string)
            layer = hashing.Hashing(num_bins=2)
            int_data = layer(input_data)
            model = keras.Model(inputs=input_data, outputs=int_data)
        output_dataset = model.predict(input_dataset)
        self.assertAllEqual(expected_output, output_dataset)
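With no salt, Hashing uses FarmHash64, so bin assignments are deterministic across runs and platforms; that is what makes hard-coding expected_output safe. A standalone eager check:

layer = hashing.Hashing(num_bins=2)
print(layer(np.asarray([["omar"], ["stringer"], ["marlo"], ["wire"]])))
# -> [[0], [0], [1], [0]], matching expected_output above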
Example #15
    def testCollectiveReduceGroupAssignment(self):
        if not test_util.is_mlir_bridge_enabled():
            self.skipTest('AssignGroup is only supported in the MLIR bridge.')

        with ops.device('device:{}:0'.format(self.device)):

            @def_function.function(jit_compile=True)
            def fn(x):
                group_key = collective_ops.assign_group_v2(
                    group_assignment=[[0]], device_index=0)
                t0 = collective_ops.all_reduce_v2(t=x,
                                                  group_size=1,
                                                  group_key=group_key,
                                                  instance_key=1)
                return t0

            inputs = constant_op.constant([1.0, 2.0, 3.0])
            # Check the compiled HLO: the group assignment above should show
            # up as replica_groups={{0}} on the all-reduce instruction.
            hlo_str = fn.experimental_get_compiler_ir(inputs)()
            self.assertIn('replica_groups={{0}}', hlo_str)
Example #16
  def testNestedCallUnsupportedOps(self):
    if 'tpu' in self.device.lower():
      self.skipTest('Outside compilation will extract string_length to CPU')

    with ops.device('device:{}:0'.format(self.device)):

      def fn(x):
        return string_ops.string_length(
            string_ops.string_format('{}', x))

      xla_func = def_function.function(fn, jit_compile=True)

      def fn2(x):
        return xla_func(x)

      func = def_function.function(fn2, jit_compile=False)
      inputs = constant_op.constant([1, 2, 2, 3, 3])
      with self.assertRaisesRegex(
          errors.InvalidArgumentError, 'legalization failed'
          if test_util.is_mlir_bridge_enabled() else 'unsupported operations'):
        func(inputs)
  def test_logical_device_assignment(self):
    if test_util.is_mlir_bridge_enabled():
      self.skipTest("TODO(b/238811067): fix MLIR bridge")
    strategy, num_replicas = get_tpu_strategy()
    with strategy.scope():
      v = variables.Variable(2.)
      with strategy.extended.experimental_logical_device(1):
        w = variables.Variable(3.)

    self.assertLen(strategy.experimental_local_results(v), num_replicas)
    self.assertLen(strategy.experimental_local_results(w), num_replicas)
    self.assertEqual("/job:localhost/replica:0/task:0/device:TPU:0",
                     strategy.experimental_local_results(v)[0].device)
    self.assertEqual("/job:localhost/replica:0/task:0/device:TPU:1",
                     strategy.experimental_local_results(w)[0].device)

    logical_devices = []

    @def_function.function
    def f(x):
      replica_ctx = distribution_strategy_context.get_replica_context()
      with replica_ctx.experimental_logical_device(0):
        y = v * x
      with replica_ctx.experimental_logical_device(1):
        z = w * y
      logical_devices.append((y.device, z.device))
      return z

    result = strategy.run(f, args=(5.,))

    self.assertEqual(
        [("/device:TPU_REPLICATED_CORE:0", "/device:TPU_REPLICATED_CORE:1")],
        logical_devices)

    with self.cached_session():
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(30. * num_replicas,
                       self.evaluate(strategy.reduce("SUM", result, axis=None)))
    def testTensorArrayErrorMessage(self):
        with ops.device('device:{}:0'.format(self.device)):

            @def_function.function(jit_compile=True)
            def f():
                # The old and new bridges flag different ops, so the expected
                # error messages differ: the new bridge points to the creation
                # of the uninitialized tensor array, while the old bridge
                # points to its use.
                ta = tensor_array_ops.TensorArray(  # EXPECTED_MESSAGE_NEW
                    dtype=dtypes.float32,
                    size=2,
                    dynamic_size=True,
                    element_shape=(None, ))
                return ta.concat()  # EXPECTED_MESSAGE_OLD

            if test_util.is_mlir_bridge_enabled():
                with self.assertRaisesRegex(errors.InvalidArgumentError,
                                            'EXPECTED_MESSAGE_NEW'):
                    f()
            else:
                with self.assertRaisesRegex(errors.InvalidArgumentError,
                                            'EXPECTED_MESSAGE_OLD'):
                    f()
Example #19
    def __init__(self, method_name='runTest'):
        super(XLATestCase, self).__init__(method_name)
        context.context().enable_mlir_bridge = (
            test_util.is_mlir_bridge_enabled())

        self.device = FLAGS.test_device
        self.has_custom_call = (self.device == 'XLA_CPU')
        self._all_tf_types = set([
            dtypes.as_dtype(types_pb2.DataType.Value(name))
            for name in FLAGS.types.split(',')
        ])
        self.int_tf_types = set(
            [dtype for dtype in self._all_tf_types if dtype.is_integer])
        self._float_tf_types = set(
            [dtype for dtype in self._all_tf_types if dtype.is_floating])
        self.complex_tf_types = set(
            [dtype for dtype in self._all_tf_types if dtype.is_complex])
        self._numeric_tf_types = set(self.int_tf_types | self._float_tf_types
                                     | self.complex_tf_types)
        self.quantized_tf_types = set(dtype for dtype in self._all_tf_types
                                      if dtype.is_quantized)

        # Quantized types don't have a numpy equivalent, so include them in
        # all_tf_types but not in all_types.
        # TODO(b/115960798): Parametrize tests on TF types instead of numpy types
        # and remove all_types.
        self._all_types = set(dtype.as_numpy_dtype
                              for dtype in self._all_tf_types
                              if not dtype.is_quantized)
        self._int_types = set(
            [dtype.as_numpy_dtype for dtype in self.int_tf_types])
        self.signed_int_types = set(dtype.as_numpy_dtype
                                    for dtype in self.int_tf_types
                                    if not dtype.is_unsigned)
        self.unsigned_int_types = set(dtype.as_numpy_dtype
                                      for dtype in self.int_tf_types
                                      if dtype.is_unsigned)
        self._float_types = set(
            [dtype.as_numpy_dtype for dtype in self._float_tf_types])
        self.complex_types = set(
            [dtype.as_numpy_dtype for dtype in self.complex_tf_types])
        self._numeric_types = set(self._int_types | self._float_types
                                  | self.complex_types)

        # Parse the manifest file, if any, into a regex identifying tests to
        # disable.
        # TODO(xpan): Make it text proto if it doesn't scale.
        # Each line of the manifest file specifies an entry, which is either
        # 1) TestNameRegex  // e.g. CumprodTest.*, or
        # 2) TestName TypeName  // e.g. AdamOptimizerTest.testSharing DT_BFLOAT16
        # Form 1) disables the entire test, while form 2) only filters out some
        # numeric types so that they are not used in those tests.
        self.disabled_regex = None
        self._method_types_filter = {}

        if FLAGS.disabled_manifest is not None:
            with open(FLAGS.disabled_manifest, 'r') as manifest_file:
                disabled_regex, self._method_types_filter = (
                    parse_disabled_manifest(manifest_file.read()))
                if disabled_regex:
                    self.disabled_regex = re.compile(disabled_regex)

        if FLAGS.tf_xla_flags is not None:
            os.environ['TF_XLA_FLAGS'] = FLAGS.tf_xla_flags
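For illustration, a disabled-manifest file in the format described by the comments above would look like this on disk (entries copied from the comment's own examples):

CumprodTest.*
AdamOptimizerTest.testSharing DT_BFLOAT16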
Example #20
    def testGradientTraining(self, data_format):
        # Disable the MLIR bridge for GPUs, as there is no FusedBatchNorm
        # legalization for GPU with MLIR.
        # TODO(b/189039456): Customize FusedBatchNorm legalization for GPU in MLIR.
        if test_util.is_mlir_bridge_enabled() and self.device == "XLA_GPU":
            self.skipTest("b/189039456")

        # TODO(b/64270657): Use gradient_checker here in addition to comparing with
        # this reference implementation.
        channel = 3
        x_shape = [2, 2, 6, channel]
        scale_shape = [channel]
        grad_val = np.random.random_sample(x_shape).astype(np.float32)
        x_val = np.random.random_sample(x_shape).astype(np.float32)
        scale_val = np.random.random_sample(scale_shape).astype(np.float32)
        mean_val = np.random.random_sample(scale_shape).astype(np.float32)
        var_val = np.random.random_sample(scale_shape).astype(np.float32)
        epsilon = 0.001

        # The TensorFlow FusedBatchNormGrad training operation takes two inputs
        # with implementation-defined values. In theory the only correct values
        # for these inputs are the corresponding reserve_space_{1|2} outputs
        # from the FusedBatchNorm training operation. However, in practice, we
        # rely on the first one being mean on {C|G}PU, and the second one being
        # variance on CPU and inverse(sqrt(variance + epsilon)) on GPU (we test
        # this assumption separately).
        reserve_space_1_val = mean_val
        if self.device == "XLA_GPU":
            reserve_space_2_val = np.reciprocal(np.sqrt(var_val + epsilon))
        else:
            reserve_space_2_val = var_val

        data_format_src = "NHWC"
        grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad(
            x_val, grad_val, scale_val, mean_val, var_val, epsilon,
            data_format_src)

        with self.session() as sess, self.test_scope():
            grad_val_converted = test_utils.ConvertBetweenDataFormats(
                grad_val, data_format_src, data_format)
            x_val_converted = test_utils.ConvertBetweenDataFormats(
                x_val, data_format_src, data_format)
            grad_x_ref_converted = test_utils.ConvertBetweenDataFormats(
                grad_x_ref, data_format_src, data_format)

            grad = array_ops.placeholder(np.float32,
                                         shape=x_val_converted.shape,
                                         name="grad")
            x = array_ops.placeholder(np.float32,
                                      shape=x_val_converted.shape,
                                      name="x")
            reserve_space_1 = array_ops.placeholder(np.float32,
                                                    shape=scale_shape,
                                                    name="reserve_space_1")
            reserve_space_2 = array_ops.placeholder(np.float32,
                                                    shape=scale_shape,
                                                    name="reserve_space_2")
            scale = array_ops.placeholder(np.float32,
                                          shape=scale_shape,
                                          name="scale")
            grad_x, grad_scale, grad_offset, _, _ = gen_nn_ops.fused_batch_norm_grad(
                grad,
                x,
                scale,
                reserve_space_1,
                reserve_space_2,
                data_format=data_format,
                is_training=True)

            grad_x_val, grad_scale_val, grad_offset_val = sess.run(
                [grad_x, grad_scale, grad_offset], {
                    grad: grad_val_converted,
                    x: x_val_converted,
                    reserve_space_1: reserve_space_1_val,
                    reserve_space_2: reserve_space_2_val,
                    scale: scale_val
                })

            self.assertAllClose(grad_x_val, grad_x_ref_converted, atol=1e-2)
            self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2)
            self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3)
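The _reference_grad helper is the NumPy reference implementation mentioned in the comments. A sketch of the standard batch-norm training gradients, reducing over N, H, W (this assumes NHWC input; the real helper takes data_format_src and is more general):

def _reference_grad(self, x, grad_y, scale, mean, var, epsilon, data_format):
    # Sketch only: batch-norm gradients for is_training=True, NHWC layout.
    assert data_format == "NHWC"
    inv_std = 1.0 / np.sqrt(var + epsilon)
    x_hat = (x - mean) * inv_std
    grad_offset = np.sum(grad_y, axis=(0, 1, 2))
    grad_scale = np.sum(grad_y * x_hat, axis=(0, 1, 2))
    grad_x = scale * inv_std * (
        grad_y - np.mean(grad_y, axis=(0, 1, 2)) -
        x_hat * np.mean(grad_y * x_hat, axis=(0, 1, 2)))
    return grad_x, grad_scale, grad_offset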
Example #21
  def setUp(self):
    super().setUp()
    self.rewrite_ops_for_tpu = ("TPU" in self.device
                                and test_util.is_mlir_bridge_enabled())
Example #22
    def __init__(self, method_name='runTest'):
        super(XLATestCase, self).__init__(method_name)
        if 'XLA' in FLAGS.test_device:
            context.context().enable_xla_devices()

        # Check if the MLIR bridge has been explicitly enabled or disabled. If
        # is_mlir_bridge_enabled() returns None, the user did not explicitly
        # enable or disable the bridge, so do not update enable_mlir_bridge.
        if test_util.is_mlir_bridge_enabled():
            context.context().enable_mlir_bridge = True
        elif test_util.is_mlir_bridge_enabled() is not None:
            context.context().enable_mlir_bridge = False

        self.device = FLAGS.test_device
        self.has_custom_call = (self.device == 'XLA_CPU')

        # Some tests (e.g. ftrl_ops) only work if the program goes through the
        # _TPUCompileMLIR op. They will set this flag to True.
        # TODO(kramm): Flip to true (and enable MLIR bridge) for more tests.
        self.rewrite_ops_for_tpu = False

        self._all_tf_types = set([
            dtypes.as_dtype(types_pb2.DataType.Value(name))
            for name in FLAGS.types.split(',')
        ])
        self.int_tf_types = set(
            [dtype for dtype in self._all_tf_types if dtype.is_integer])
        self._float_tf_types = set(
            [dtype for dtype in self._all_tf_types if dtype.is_floating])
        self.complex_tf_types = set(
            [dtype for dtype in self._all_tf_types if dtype.is_complex])
        self._numeric_tf_types = set(self.int_tf_types | self._float_tf_types
                                     | self.complex_tf_types)
        self.quantized_tf_types = set(dtype for dtype in self._all_tf_types
                                      if dtype.is_quantized)

        # Quantized types don't have a numpy equivalent, so include them in
        # all_tf_types but not in all_types.
        # TODO(b/115960798): Parametrize tests on TF types instead of numpy types
        # and remove all_types.
        self._all_types = set(dtype.as_numpy_dtype
                              for dtype in self._all_tf_types
                              if not dtype.is_quantized)
        self._int_types = set(
            [dtype.as_numpy_dtype for dtype in self.int_tf_types])
        self.signed_int_types = set(dtype.as_numpy_dtype
                                    for dtype in self.int_tf_types
                                    if not dtype.is_unsigned)
        self.unsigned_int_types = set(dtype.as_numpy_dtype
                                      for dtype in self.int_tf_types
                                      if dtype.is_unsigned)
        self._float_types = set(
            [dtype.as_numpy_dtype for dtype in self._float_tf_types])
        self.complex_types = set(
            [dtype.as_numpy_dtype for dtype in self.complex_tf_types])
        self._numeric_types = set(self._int_types | self._float_types
                                  | self.complex_types)

        # Parse the manifest file, if any, into a regex identifying tests to
        # disable.
        # TODO(xpan): Make it text proto if it doesn't scale.
        # Each line of the manifest file specifies an entry, which is either
        # 1) TestNameRegex  // e.g. CumprodTest.*, or
        # 2) TestName TypeName  // e.g. AdamOptimizerTest.testSharing DT_BFLOAT16
        # Form 1) disables the entire test, while form 2) only filters out some
        # numeric types so that they are not used in those tests.
        self.disabled_regex = None
        self._method_types_filter = {}

        if FLAGS.disabled_manifest is not None:
            with open(FLAGS.disabled_manifest, 'r') as manifest_file:
                disabled_regex, self._method_types_filter = (
                    parse_disabled_manifest(manifest_file.read()))
                if disabled_regex:
                    self.disabled_regex = re.compile(disabled_regex)

        if FLAGS.tf_xla_flags is not None:
            os.environ['TF_XLA_FLAGS'] = FLAGS.tf_xla_flags