# Imports assumed by the snippets below. The image_preprocessing module path
# follows the internal TF-Keras layout and may differ across TF versions.
import time

import numpy as np
import tensorflow as tf

from tensorflow import keras
from tensorflow.python.keras.layers.preprocessing import image_preprocessing


def test_distribution(self, strategy):
  if "CentralStorage" in type(strategy).__name__:
    self.skipTest("Does not work with CentralStorageStrategy yet.")
  # TODO(b/159738418): large image input causes OOM in ubuntu multi gpu.
  np_images = np.random.random((32, 32, 32, 3)).astype(np.float32)
  image_dataset = tf.data.Dataset.from_tensor_slices(np_images).batch(
      16, drop_remainder=True)

  with strategy.scope():
    input_data = keras.Input(shape=(32, 32, 3), dtype=tf.float32)
    image_preprocessor = keras.Sequential([
        image_preprocessing.Resizing(height=256, width=256),
        image_preprocessing.RandomCrop(height=224, width=224),
        image_preprocessing.RandomTranslation(.1, .1),
        image_preprocessing.RandomRotation(.2),
        image_preprocessing.RandomFlip(),
        image_preprocessing.RandomZoom(.2, .2)
    ])
    preprocessed_image = image_preprocessor(input_data)
    flatten_layer = keras.layers.Flatten(data_format="channels_last")
    output = flatten_layer(preprocessed_image)
    cls_layer = keras.layers.Dense(units=1, activation="sigmoid")
    output = cls_layer(output)
    model = keras.Model(inputs=input_data, outputs=output)
  _ = model.predict(image_dataset)
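# --- Illustrative usage sketch (not from the original file). ---
# A minimal, self-contained way to exercise the same preprocessing pipeline
# under a concrete strategy; tf.distribute.MirroredStrategy and the helper
# name below are assumptions, not the suite's actual test fixtures.
def _smoke_run_under_mirrored_strategy():
  strategy = tf.distribute.MirroredStrategy()
  with strategy.scope():
    model = keras.Sequential([
        image_preprocessing.Resizing(height=256, width=256),
        image_preprocessing.RandomCrop(height=224, width=224),
        keras.layers.Flatten(),
        keras.layers.Dense(units=1, activation="sigmoid"),
    ])
  images = np.random.random((8, 32, 32, 3)).astype(np.float32)
  dataset = tf.data.Dataset.from_tensor_slices(images).batch(
      4, drop_remainder=True)
  _ = model.predict(dataset)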
def bm_layer_implementation(self, batch_size):
  with tf.device("/gpu:0"):
    img = keras.Input(shape=(256, 256, 3), dtype=tf.float32)
    preprocessor = keras.Sequential([
        image_preprocessing.Resizing(224, 224),
        image_preprocessing.RandomCrop(height=224, width=224),
        image_preprocessing.RandomRotation(factor=(.2, .4)),
        image_preprocessing.RandomFlip(mode="horizontal"),
        image_preprocessing.RandomZoom(.2, .2)
    ])
    # Calling the model on a symbolic input builds its layers up front, so
    # layer construction is not counted in the timed loop below.
    _ = preprocessor(img)

    num_repeats = 5
    starts = []
    ends = []
    for _ in range(num_repeats):
      ds = tf.data.Dataset.from_tensor_slices(
          np.random.random((batch_size, 256, 256, 3)))
      ds = ds.shuffle(batch_size * 100)
      ds = ds.batch(batch_size)
      ds = ds.prefetch(batch_size)
      starts.append(time.time())
      count = 0
      # Benchmarked code begins here.
      for i in ds:
        _ = preprocessor(i)
        count += 1
      # Benchmarked code ends here.
      ends.append(time.time())

  avg_time = np.mean(np.array(ends) - np.array(starts)) / count
  name = "image_preprocessing|batch_%s" % batch_size
  baseline = self.run_dataset_implementation(batch_size)
  extras = {
      "dataset implementation baseline": baseline,
      "delta seconds": (baseline - avg_time),
      "delta percent": ((baseline - avg_time) / baseline) * 100
  }
  self.report_benchmark(
      iters=num_repeats, wall_time=avg_time, extras=extras, name=name)
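# --- Illustrative sketch (assumption): run_dataset_implementation is called
# above but not shown in this excerpt. A baseline of that shape typically
# applies the same preprocessing inside the tf.data pipeline via Dataset.map
# and times one pass; the body below guesses at its structure and is not the
# real helper.
def _run_dataset_implementation_sketch(batch_size):
  preprocessor = keras.Sequential([
      image_preprocessing.Resizing(224, 224),
      image_preprocessing.RandomCrop(height=224, width=224),
  ])
  ds = tf.data.Dataset.from_tensor_slices(
      np.random.random((batch_size, 256, 256, 3)))
  ds = ds.batch(batch_size)
  # Preprocess inside the input pipeline instead of as a model layer.
  ds = ds.map(lambda x: preprocessor(x, training=True))
  start = time.time()
  for _ in ds:
    pass
  return time.time() - start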