def test_distribution_strategy_output_with_adapt(self, strategy):
        """Checks int output of an adapted TextVectorization under `strategy`.

        A TextVectorization layer in INT mode (no standardization, no
        splitting, `max_tokens=None`) is adapted on a small token dataset
        inside `strategy.scope()`, wrapped in a Keras model, and run through
        `model.predict` on a single batch of two rows. The expected indices
        follow the token counts in the adapt data ("earth" -> 2 down to
        "fire" -> 5); the one token absent from it ("michigan") is expected
        to come back as 1.

        The test is skipped when `strategy` is a TPU strategy.

        Args:
            strategy: Distribution strategy whose scope is used to build the
                layer and the model.
        """
        # NOTE(review): a second method with this exact name follows later in
        # this file. If both live in the same class, the later definition
        # replaces this one and the TPU skip below never runs -- confirm.

        # TODO(b/180614455): remove this check when MLIR bridge is always enabled.
        if backend.is_tpu_strategy(strategy):
            self.skipTest("This test needs MLIR bridge on TPU.")

        # A single row of tokens: 4x "earth", 3x "wind", 2x "and", 1x "fire".
        adapt_tokens = [["earth"] * 4 + ["wind"] * 3 + ["and"] * 2 + ["fire"]]
        adapt_ds = dataset_ops.Dataset.from_tensors(adapt_tokens)

        queries = np.array([
            ["earth", "wind", "and", "fire"],
            ["fire", "and", "earth", "michigan"],
        ])
        sliced_ds = dataset_ops.Dataset.from_tensor_slices(queries)
        # Both rows form exactly one batch of size 2.
        predict_ds = sliced_ds.batch(2, drop_remainder=True)

        config.set_soft_device_placement(True)

        with strategy.scope():
            tokens_in = keras.Input(shape=(None,), dtype=dtypes.string)
            vectorizer = text_vectorization.TextVectorization(
                max_tokens=None,
                standardize=None,
                split=None,
                output_mode=text_vectorization.INT,
            )
            vectorizer.adapt(adapt_ds)
            model = keras.Model(inputs=tokens_in, outputs=vectorizer(tokens_in))

        predictions = model.predict(predict_ds)
        want = [[2, 3, 4, 5], [5, 4, 2, 1]]
        self.assertAllEqual(want, predictions)
    def test_distribution_strategy_output_with_adapt(self, strategy):
        """Verifies adapted TextVectorization int output under `strategy`.

        Builds a TextVectorization layer (INT output mode, no
        standardization or splitting, `max_tokens=None`) inside
        `strategy.scope()`, adapts it on a one-element dataset of tokens,
        and compares `model.predict` on a two-row batch against fixed
        expected indices. In the expected values the tokens seen during
        adapt map to 2..5 in order of decreasing count, and the unseen
        token "michigan" maps to 1.

        Args:
            strategy: Distribution strategy whose scope wraps layer and
                model construction.
        """
        # NOTE(review): an earlier method with this exact name appears above
        # in this file and additionally skips on TPU strategies. If both are
        # in the same class, this definition replaces that one -- confirm
        # which version is intended.
        corpus = [[
            "earth", "earth", "earth", "earth",
            "wind", "wind", "wind",
            "and", "and",
            "fire",
        ]]
        corpus_ds = dataset_ops.Dataset.from_tensors(corpus)

        row_a = ["earth", "wind", "and", "fire"]
        row_b = ["fire", "and", "earth", "michigan"]
        lookup_inputs = np.array([row_a, row_b])
        unbatched_ds = dataset_ops.Dataset.from_tensor_slices(lookup_inputs)
        # The two rows are delivered as one full batch.
        lookup_ds = unbatched_ds.batch(2, drop_remainder=True)

        config.set_soft_device_placement(True)

        with strategy.scope():
            string_input = keras.Input(shape=(None,), dtype=dtypes.string)
            vectorize = text_vectorization.TextVectorization(
                max_tokens=None,
                standardize=None,
                split=None,
                output_mode=text_vectorization.INT,
            )
            vectorize.adapt(corpus_ds)
            indices = vectorize(string_input)
            model = keras.Model(inputs=string_input, outputs=indices)

        actual = model.predict(lookup_ds)
        expected = [[2, 3, 4, 5], [5, 4, 2, 1]]
        self.assertAllEqual(expected, actual)