def test_distribution(self, strategy):
        """Smoke-test a stack of image preprocessing layers under a
        distribution strategy by running `predict` on a small random dataset.
        """
        if "CentralStorage" in type(strategy).__name__:
            self.skipTest("Does not work with CentralStorageStrategy yet.")
        # TODO(b/159738418): large image input causes OOM in ubuntu multi gpu.
        raw_images = np.random.random((32, 32, 32, 3)).astype(np.float32)
        dataset = tf.data.Dataset.from_tensor_slices(raw_images).batch(
            16, drop_remainder=True)

        with strategy.scope():
            inputs = keras.Input(shape=(32, 32, 3), dtype=tf.float32)
            # Augmentation pipeline: resize, crop, then random geometric
            # transforms, applied as regular layers inside the model.
            augment = keras.Sequential([
                image_preprocessing.Resizing(height=256, width=256),
                image_preprocessing.RandomCrop(height=224, width=224),
                image_preprocessing.RandomTranslation(.1, .1),
                image_preprocessing.RandomRotation(.2),
                image_preprocessing.RandomFlip(),
                image_preprocessing.RandomZoom(.2, .2),
            ])
            features = keras.layers.Flatten(
                data_format="channels_last")(augment(inputs))
            predictions = keras.layers.Dense(
                units=1, activation="sigmoid")(features)
            model = keras.Model(inputs=inputs, outputs=predictions)
        _ = model.predict(dataset)
    def bm_layer_implementation(self, batch_size):
        with tf.device("/gpu:0"):
            img = keras.Input(shape=(256, 256, 3), dtype=tf.float32)
            preprocessor = keras.Sequential([
                image_preprocessing.Resizing(224, 224),
                image_preprocessing.RandomCrop(height=224, width=224),
                image_preprocessing.RandomRotation(factor=(.2, .4)),
                image_preprocessing.RandomFlip(mode="horizontal"),
                image_preprocessing.RandomZoom(.2, .2)
            ])
            _ = preprocessor(img)

            num_repeats = 5
            starts = []
            ends = []
            for _ in range(num_repeats):
                ds = tf.data.Dataset.from_tensor_slices(
                    np.random.random((batch_size, 256, 256, 3)))
                ds = ds.shuffle(batch_size * 100)
                ds = ds.batch(batch_size)
                ds = ds.prefetch(batch_size)
                starts.append(time.time())
                count = 0
                # Benchmarked code begins here.
                for i in ds:
                    _ = preprocessor(i)
                    count += 1
                # Benchmarked code ends here.
                ends.append(time.time())

        avg_time = np.mean(np.array(ends) - np.array(starts)) / count
        name = "image_preprocessing|batch_%s" % batch_size
        baseline = self.run_dataset_implementation(batch_size)
        extras = {
            "dataset implementation baseline": baseline,
            "delta seconds": (baseline - avg_time),
            "delta percent": ((baseline - avg_time) / baseline) * 100
        }
        self.report_benchmark(iters=num_repeats,
                              wall_time=avg_time,
                              extras=extras,
                              name=name)