def test_some_processing_functions(synthetic_dataset, reader_factory):
    """Try several ``tf.data.Dataset`` dataset operations on make_petastorm_dataset"""

    # reader1 will have a single row with id=1, reader2: a single row with id=2

    # Using functools.partial(operator.eq, 1), which is equivalent to lambda x: x == 1, because standard
    # python pickle cannot pickle such a lambda
    with reader_factory(synthetic_dataset.url,
                        predicate=in_lambda(['id'], functools.partial(operator.eq, 1))) as reader1:
        with reader_factory(synthetic_dataset.url,
                            predicate=in_lambda(['id'], functools.partial(operator.eq, 2))) as reader2:
            dataset = make_petastorm_dataset(reader1) \
                .prefetch(10) \
                .concatenate(make_petastorm_dataset(reader2)) \
                .map(lambda x: x.id) \
                .batch(2)

            next_sample = dataset.make_one_shot_iterator().get_next()

            with tf.Session() as sess:
                # 'actual' is expected to be the content of the id column of the concatenated dataset
                actual = sess.run(next_sample)
                np.testing.assert_array_equal(actual, [1, 2])
def test_dataset_on_ngram_not_supported(synthetic_dataset, reader_factory):
    ngram = NGram({
        0: list(_NOT_NULL_FIELDS),
        1: [TestSchema.id]
    }, 100, TestSchema.id)
    with reader_factory(synthetic_dataset.url, schema_fields=ngram) as reader:
        with pytest.raises(NotImplementedError):
            make_petastorm_dataset(reader)
def test_with_one_shot_iterator(synthetic_dataset, reader_factory):
    """Reads all samples and compares each value to the expected values"""
    with reader_factory(synthetic_dataset.url) as reader:
        dataset = make_petastorm_dataset(reader)
        iterator = dataset.make_one_shot_iterator()

        # Make sure we have static shape info for all fields
        for shape in dataset.output_shapes:
            # TODO(yevgeni): check that the shapes are actually correct, not just not None
            assert shape.dims is not None

        # Read a bunch of entries from the dataset and compare the data to reference
        with tf.Session() as sess:
            iterator = iterator.get_next()
            for _, _ in enumerate(synthetic_dataset.data):
                actual = sess.run(iterator)._asdict()
                expected = next(d for d in synthetic_dataset.data if d['id'] == actual['id'])
                for key in actual.keys():
                    if isinstance(expected[key], str):
                        # Tensorflow returns all strings as bytes in python3, so we need to decode them
                        actual_value = actual[key].decode()
                    elif isinstance(expected[key], np.ndarray) and expected[key].dtype.type == np.unicode_:
                        actual_value = np.array([item.decode() for item in actual[key]])
                    else:
                        actual_value = actual[key]
                    np.testing.assert_equal(actual_value, expected[key])

            # Exhausted one full epoch. Fetching the next value should trigger OutOfRangeError
            with pytest.raises(tf.errors.OutOfRangeError):
                sess.run(iterator)
def test_with_dataset_repeat_after_cache(synthetic_dataset, reader_factory):
    """Check if ``tf.data.Dataset``'s ``repeat`` works after ``tf.data.Dataset``'s ``cache``."""
    epochs = 3
    with reader_factory(synthetic_dataset.url, schema_fields=[TestSchema.id]) as reader:
        dataset = make_petastorm_dataset(reader)
        dataset = dataset.cache()
        dataset = dataset.repeat(epochs)

        iterator = dataset.make_one_shot_iterator()
        it_op = iterator.get_next()

        # Check if the dataset generates the same result in every epoch.
        with tf.Session() as sess:
            with pytest.warns(None):
                # Expect no warnings since cache() is called before repeat()
                for _ in range(epochs):
                    actual_res = []
                    for _, _ in enumerate(synthetic_dataset.data):
                        actual = sess.run(it_op)._asdict()
                        actual_res.append(actual["id"])
                    expected_res = list(range(len(synthetic_dataset.data)))
                    # sort dataset output since row_groups are shuffled by the reader.
                    np.testing.assert_equal(sorted(actual_res), expected_res)

            # Exhausted all epochs. Fetching the next value should trigger OutOfRangeError
            with pytest.raises(tf.errors.OutOfRangeError):
                sess.run(it_op)
def fn(reader, batch_size, shuffle_buffer_size, is_batch_reader, shuffle=False, cache=False, seed=None):
    from petastorm.tf_utils import make_petastorm_dataset

    dataset = make_petastorm_dataset(reader)
    if is_batch_reader:
        dataset = dataset.apply(tf.data.experimental.unbatch())

    # Apply cache() before shuffle, so we can reshuffle in each iteration.
    if cache:
        dataset = dataset.cache()

    if shuffle:
        dataset = dataset.shuffle(shuffle_buffer_size, seed=seed)

    # Use tf.data.Dataset.repeat() to set up an infinite iterator
    # and to enable ranks to perform training and validation with
    # unequal numbers of samples.
    # FIXME(chongxiaoc): Use a very large number (10^9) for enough loops.
    # None and -1 are not working with petastorm dataset for repeating.
    # Verify this parameter again in the future with new TF and Petastorm versions.
    dataset = dataset.repeat(1000000000)

    # Decompress sparse data if necessary
    if has_sparse_col:
        dataset = dataset.batch(1).map(reshape)

    dataset = dataset.batch(batch_size).map(prep_data_tf_keras)

    if hasattr(tf.data, 'AUTOTUNE'):
        dataset = dataset.prefetch(tf.data.AUTOTUNE)
    else:
        dataset = dataset.prefetch(1)

    return dataset
def get_dataset(reader, shuffle, batch):
    def make_window_dataset(ds, window_size=64, shift=64, stride=1):
        windows = ds.window(window_size, shift=shift, stride=stride)

        def sub_to_batch(sub):
            return sub.batch(window_size, drop_remainder=True)

        windows = windows.flat_map(sub_to_batch)
        return windows

    def create_dataset(features, label):
        features_ = tf.data.Dataset.from_tensor_slices(
            tf.reshape(features, shape=[128, 2]))
        label_ = tf.concat([label, tf.constant([0.])], 0)
        return (make_window_dataset(features_),
                make_window_dataset(tf.data.Dataset.from_tensor_slices(label_[1:])))

    def func(x, y):
        return tf.data.Dataset.zip((x, y))

    def mapping(x):
        return (tf.stack([x.user_list, x.item_list], 2), x.item_list)

    def reshape(x, y):
        return (x, tf.cast(tf.reshape(y, [-1]), tf.int32))

    features = make_petastorm_dataset(reader).map(mapping).unbatch()
    features_windows = features.map(create_dataset).flat_map(func)
    return features_windows.shuffle(shuffle).batch(batch).map(reshape)
def test_with_tf_data_api(synthetic_dataset):
    """Verify that WeightedSamplingReader is compatible with make_petastorm_dataset"""
    np.random.seed(42)

    fields_to_read = ['id.*', 'image_png']

    # Use cur_shard=0, shard_count=2 to get only half of the samples from the second reader.
    readers = [make_reader(synthetic_dataset.url, schema_fields=fields_to_read, workers_count=1),
               make_reader(synthetic_dataset.url, schema_fields=fields_to_read, workers_count=1,
                           cur_shard=0, shard_count=2)]

    with WeightedSamplingReader(readers, [0.5, 0.5]) as mixer:
        dataset = make_petastorm_dataset(mixer)
        iterator = dataset.make_one_shot_iterator()
        tensor = iterator.get_next()

        rows_count = 0
        with tf.Session() as sess:
            while True:
                try:
                    sess.run(tensor)
                    rows_count += 1
                except tf.errors.OutOfRangeError:
                    break

        # We expect iterations to finish once the second reader has exhausted its samples. For each sample
        # read from the second reader we read approximately one sample from the first.
        expected_rows_approx = len(synthetic_dataset.data)
        np.testing.assert_allclose(rows_count, expected_rows_approx, atol=20)
def __enter__(self):
    # import locally to avoid importing tensorflow globally.
    from petastorm.tf_utils import make_petastorm_dataset
    import tensorflow.compat.v1 as tf  # pylint: disable=import-error

    _wait_file_available(self.parquet_file_url_list)
    self.reader = make_batch_reader(self.parquet_file_url_list, **self.petastorm_reader_kwargs)

    # unroll dataset
    dataset = make_petastorm_dataset(self.reader).flat_map(
        tf.data.Dataset.from_tensor_slices)

    # TODO: auto tune best batch size in default case.
    batch_size = self.batch_size or 32
    dataset = dataset.batch(batch_size=batch_size)

    prefetch = self.prefetch

    if prefetch is None:
        if LooseVersion(tf.__version__) >= LooseVersion('1.14'):
            # We can make prefetch optimization
            prefetch = tf.data.experimental.AUTOTUNE
        else:
            prefetch = 1

    dataset = dataset.prefetch(prefetch)

    return dataset
def __init__(self, data_url, batch_size, prefetch, preproc_fn, preproc_parallelism):
    """
    :param data_url: A string specifying the data URL.
    :param batch_size: batch size of the generated tf.data.Dataset
    :param prefetch: prefetch value for the tf dataset
    :param preproc_fn: preprocessing function
    :param preproc_parallelism: parallelism for the preprocessing function
    """
    from petastorm.tf_utils import make_petastorm_dataset
    import tensorflow as tf

    def support_prefetch_and_autotune():
        return LooseVersion(tf.__version__) >= LooseVersion('1.14')

    self.reader = petastorm.make_batch_reader(data_url)
    self.dataset = make_petastorm_dataset(self.reader) \
        .flat_map(tf.data.Dataset.from_tensor_slices)

    self.dataset = self.dataset.batch(batch_size=batch_size)

    if support_prefetch_and_autotune():
        if prefetch is None:
            prefetch = tf.data.experimental.AUTOTUNE
        if prefetch != 0:
            self.dataset = self.dataset.prefetch(prefetch)

    if preproc_fn is not None:
        if preproc_parallelism is None:
            if support_prefetch_and_autotune():
                preproc_parallelism = tf.data.experimental.AUTOTUNE
            else:
                preproc_parallelism = 1
        self.dataset = self.dataset.map(preproc_fn, preproc_parallelism)
def __init__(self, data_url):
    """
    :param data_url: A string specifying the data URL.
    """
    from petastorm.tf_utils import make_petastorm_dataset

    self.reader = make_batch_reader(data_url)
    self.dataset = make_petastorm_dataset(self.reader)
def _input_fn(reader, batch_size, num_parallel_batches):
    dataset = (
        make_petastorm_dataset(reader)
        # Per Petastorm docs, do not add a .repeat(num_epochs) here:
        # Petastorm will cycle indefinitely through the data given `num_epochs=None`
        # provided to make_reader.
        .apply(
            tf.contrib.data.map_and_batch(
                streaming_parser,
                batch_size=batch_size,
                num_parallel_batches=num_parallel_batches)))
    return dataset
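# Note: tf.contrib.data.map_and_batch only exists in TF 1.x; the tf.contrib namespace was removed in
# TF 2.x. The following is a minimal sketch of an equivalent input function on a current TF release,
# assuming the same `streaming_parser` and reader as above (the function name `_input_fn_tf2` is
# illustrative, not from the original code).
def _input_fn_tf2(reader, batch_size, num_parallel_calls):
    # map() then batch() is the documented replacement for the fused map_and_batch transformation.
    # As above, no .repeat() is added: Petastorm cycles through the data when the reader is created
    # with num_epochs=None.
    dataset = (
        make_petastorm_dataset(reader)
        .map(streaming_parser, num_parallel_calls=num_parallel_calls)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE))
    return dataset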
def test_dataset_with_ngrams(synthetic_dataset, reader_factory):
    ngram = NGram({-1: [TestSchema.id, TestSchema.image_png], 2: [TestSchema.id]}, 100, TestSchema.id)

    with reader_factory(synthetic_dataset.url, schema_fields=ngram, num_epochs=1) as reader:
        dataset = make_petastorm_dataset(reader)
        next_sample = dataset.make_one_shot_iterator().get_next()
        with tf.Session() as sess:
            while True:
                try:
                    actual = sess.run(next_sample)
                    assert actual[-1].id + 3 == actual[2].id
                    assert np.all(actual[-1].image_png.shape > (10, 10, 3))
                except tf.errors.OutOfRangeError:
                    break
def fn(reader, shuffle_buffer_size, shuffle=False):
    from petastorm.tf_utils import make_petastorm_dataset

    dataset = make_petastorm_dataset(reader) \
        .apply(tf.data.experimental.unbatch())

    if shuffle:
        dataset = dataset.shuffle(shuffle_buffer_size)

    # Decompress sparse data if necessary
    if has_sparse_col:
        dataset = dataset.batch(1).map(reshape)

    dataset = dataset.batch(batch_size).map(prep_data_tf_keras)
    return dataset
def fn(reader, transformation_fn):
    from petastorm.tf_utils import make_petastorm_dataset

    dataset = make_petastorm_dataset(reader)

    # Decompress sparse data if necessary
    if has_sparse_col:
        dataset = dataset.batch(1).map(reshape, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if transformation_fn:
        # user provided custom transformation function
        dataset = transformation_fn(dataset)

    dataset = dataset.batch(batch_size).map(prep_data_tf_keras,
                                            num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return dataset.prefetch(tf.data.experimental.AUTOTUNE)
def tensorflow_hello_world(dataset_url='file:///tmp/hello_world_dataset'):
    # Example: tf_tensors will return tensors with dataset data
    with make_reader(dataset_url) as reader:
        tensor = tf_tensors(reader)
        with tf.Session() as sess:
            sample = sess.run(tensor)
            print(sample.id)

    # Example: use the tf.data.Dataset API
    with make_reader(dataset_url) as reader:
        dataset = make_petastorm_dataset(reader)
        iterator = dataset.make_one_shot_iterator()
        tensor = iterator.get_next()
        with tf.Session() as sess:
            sample = sess.run(tensor)
            print(sample.id)
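# The hello-world above uses the TF 1.x Session API. As a minimal sketch, assuming TF 2.x with eager
# execution and the same dataset URL, the same dataset can also be consumed by iterating over it
# directly, with no Session or one-shot iterator (the function name below is illustrative only).
import tensorflow as tf  # TF 2.x, eager execution enabled by default

from petastorm import make_reader
from petastorm.tf_utils import make_petastorm_dataset


def tensorflow_eager_hello_world(dataset_url='file:///tmp/hello_world_dataset'):
    # In eager mode the dataset yields named tuples whose fields are EagerTensors.
    with make_reader(dataset_url, num_epochs=1) as reader:
        dataset = make_petastorm_dataset(reader)
        for sample in dataset.take(3):
            print(sample.id.numpy())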
def test_non_petastorm_with_many_colums_with_one_shot_iterator(many_columns_non_petastorm_dataset):
    """Reads a sample from a non-petastorm dataset with many columns and checks the returned fields"""
    with make_batch_reader(many_columns_non_petastorm_dataset.url, workers_count=1) as reader:
        dataset = make_petastorm_dataset(reader)
        iterator = dataset.make_one_shot_iterator()

        # Make sure we have static shape info for all fields
        for shape in dataset.output_shapes:
            # TODO(yevgeni): check that the shapes are actually correct, not just not None
            assert shape.dims is not None

        # Read an entry from the dataset and compare its columns to the reference
        with tf.Session() as sess:
            iterator = iterator.get_next()
            sample = sess.run(iterator)._asdict()
            assert set(sample.keys()) == set(many_columns_non_petastorm_dataset.data[0].keys())
def tensorflow_hello_world(dataset_url='file:///tmp/external_dataset'):
    # Example: tf_tensors will return tensors with dataset data
    with make_batch_reader(dataset_url) as reader:
        tensor = tf_tensors(reader)
        with tf.Session() as sess:
            # Because we are using make_batch_reader(), each read returns a batch of rows instead of a single row
            batched_sample = sess.run(tensor)
            print("id batch: {0}".format(batched_sample.id))

    # Example: use the tf.data.Dataset API
    with make_batch_reader(dataset_url) as reader:
        dataset = make_petastorm_dataset(reader)
        iterator = dataset.make_one_shot_iterator()
        tensor = iterator.get_next()
        with tf.Session() as sess:
            batched_sample = sess.run(tensor)
            print("id batch: {0}".format(batched_sample.id))
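# A hedged sketch of the batch-reader variant under TF 2.x eager execution, assuming the same dataset
# URL as above (the function name is illustrative only). Each element produced by
# make_petastorm_dataset() on a batch reader is a batch of rows, so unbatch() yields single rows.
import tensorflow as tf  # TF 2.x

from petastorm import make_batch_reader
from petastorm.tf_utils import make_petastorm_dataset


def tensorflow_eager_batched_hello_world(dataset_url='file:///tmp/external_dataset'):
    with make_batch_reader(dataset_url, num_epochs=1) as reader:
        dataset = make_petastorm_dataset(reader).unbatch()
        for row in dataset.take(5):
            print(row.id.numpy())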
def test_non_petastorm_with_many_colums_epoch_count(many_columns_non_petastorm_dataset):
    """Reads a many-column non-petastorm dataset for a fixed number of epochs and verifies the row count"""
    expected_num_epochs = 4
    with make_batch_reader(many_columns_non_petastorm_dataset.url, workers_count=1,
                           num_epochs=expected_num_epochs) as reader:
        dataset = make_petastorm_dataset(reader)
        iterator = dataset.make_one_shot_iterator()

        # Read all entries from the dataset and count the rows
        with tf.Session() as sess:
            get_next = iterator.get_next()
            rows_count = 0
            while True:
                try:
                    sample = sess.run(get_next)._asdict()
                    rows_count += sample['col_0'].shape[0]
                except tf.errors.OutOfRangeError:
                    break

            assert expected_num_epochs * len(many_columns_non_petastorm_dataset.data) == rows_count
def test_with_dataset_repeat(synthetic_dataset, reader_factory):
    """``tf.data.Dataset``'s ``repeat`` should not be used on ``make_petastorm_dataset`` due to the high cost
    of ``Reader`` initialization. A user should use the ``Reader``'s built-in epochs support. Check that we
    raise an error to alert of the misuse."""
    with reader_factory(synthetic_dataset.url) as reader:
        dataset = make_petastorm_dataset(reader)
        dataset = dataset.repeat(2)
        iterator = dataset.make_one_shot_iterator()

        # Read one full epoch from the dataset
        with tf.Session() as sess:
            iterator = iterator.get_next()
            for _, _ in enumerate(synthetic_dataset.data):
                sess.run(iterator)

            with pytest.raises(tf.errors.UnknownError, match=r'.*Multiple iterations.*'):
                sess.run(iterator)
def test_with_dataset_repeat(synthetic_dataset, reader_factory):
    """``tf.data.Dataset``'s ``repeat`` is not recommended on ``make_petastorm_dataset`` due to the high cost
    of ``Reader`` initialization. A user should use the ``Reader``'s built-in epochs support, or
    ``tf.data.Dataset``'s ``cache``. Check that we trigger a warning to alert of the misuse."""
    with reader_factory(synthetic_dataset.url) as reader:
        dataset = make_petastorm_dataset(reader)
        dataset = dataset.repeat(2)
        iterator = dataset.make_one_shot_iterator()

        # Read one full epoch from the dataset
        with tf.Session() as sess:
            iterator = iterator.get_next()
            for _, _ in enumerate(synthetic_dataset.data):
                sess.run(iterator)

            # The match string must stay verbatim: it is the text of the warning emitted by petastorm.
            match_str = 'Running multiple iterations over make_petastorm_dataset is not recommend for performance issue'
            with pytest.warns(UserWarning, match=match_str):
                sess.run(iterator)
def initialize_batcher(self, batch_size=128, should_shuffle=True, shuffle_buffer_size=None,
                       seed=0, ignore_last=False, horovod=None):
    cur_shard, shard_count = None, None
    if horovod:
        cur_shard, shard_count = horovod.rank(), horovod.size()

    with make_batch_reader(self.url,
                           cur_shard=cur_shard,
                           shard_count=shard_count,
                           num_epochs=None) as reader:
        total_samples = self.size
        local_samples = int(total_samples / shard_count) if shard_count else total_samples

        dataset = make_petastorm_dataset(reader)
        dataset = dataset.unbatch()
        if should_shuffle:
            rows_per_piece = max([
                piece.get_metadata().num_rows for piece in reader.dataset.pieces
            ])
            buffer_size = shuffle_buffer_size or min(rows_per_piece, local_samples)
            dataset = dataset.shuffle(buffer_size)
        dataset = dataset.batch(batch_size)

        steps_per_epoch = math.ceil(local_samples / batch_size)
        batcher = IterableBatcher(self,
                                  dataset,
                                  steps_per_epoch,
                                  ignore_last=ignore_last)
        yield batcher
def train_fn(model_bytes):
    # Make sure pyarrow is referenced before anything else to avoid segfault due to conflict
    # with TensorFlow libraries. Use `pa` package reference to ensure it's loaded before
    # functions like `deserialize_model` which are implemented at the top level.
    # See https://jira.apache.org/jira/browse/ARROW-3346
    pa

    import atexit
    import horovod.tensorflow.keras as hvd
    from horovod.spark.task import get_available_devices
    import os
    from petastorm import make_batch_reader
    from petastorm.tf_utils import make_petastorm_dataset
    import tempfile
    import tensorflow as tf
    import tensorflow.keras.backend as K
    import shutil

    # Horovod: initialize Horovod inside the trainer.
    hvd.init()

    # Horovod: pin GPU to be used to process local rank (one GPU per process), if GPUs are available.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = get_available_devices()[0]
    K.set_session(tf.Session(config=config))

    # Horovod: restore from checkpoint, use hvd.load_model under the hood.
    model = deserialize_model(model_bytes, hvd.load_model)

    # Horovod: adjust learning rate based on number of processes.
    scaled_lr = K.get_value(model.optimizer.lr) * hvd.size()
    K.set_value(model.optimizer.lr, scaled_lr)

    # Horovod: print summary logs on the first worker.
    verbose = 2 if hvd.rank() == 0 else 0

    callbacks = [
        # Horovod: broadcast initial variable states from rank 0 to all other processes.
        # This is necessary to ensure consistent initialization of all workers when
        # training is started with random weights or restored from a checkpoint.
        hvd.callbacks.BroadcastGlobalVariablesCallback(root_rank=0),

        # Horovod: average metrics among workers at the end of every epoch.
        #
        # Note: This callback must be in the list before the ReduceLROnPlateau,
        # TensorBoard, or other metrics-based callbacks.
        hvd.callbacks.MetricAverageCallback(),

        # Horovod: using `lr = 1.0 * hvd.size()` from the very beginning leads to worse final
        # accuracy. Scale the learning rate `lr = 1.0` ---> `lr = 1.0 * hvd.size()` during
        # the first five epochs. See https://arxiv.org/abs/1706.02677 for details.
        hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=5, initial_lr=scaled_lr, verbose=verbose),

        # Reduce LR if the metric is not improved for 10 epochs, and stop training
        # if it has not improved for 20 epochs.
        tf.keras.callbacks.ReduceLROnPlateau(monitor='val_exp_rmspe', patience=10, verbose=verbose),
        tf.keras.callbacks.EarlyStopping(monitor='val_exp_rmspe', mode='min', patience=20, verbose=verbose),
        tf.keras.callbacks.TerminateOnNaN()
    ]

    # Model checkpoint location.
    ckpt_dir = tempfile.mkdtemp()
    ckpt_file = os.path.join(ckpt_dir, 'checkpoint.h5')
    atexit.register(lambda: shutil.rmtree(ckpt_dir))

    # Horovod: save checkpoints only on the first worker to prevent other workers from corrupting them.
    if hvd.rank() == 0:
        callbacks.append(tf.keras.callbacks.ModelCheckpoint(ckpt_file, monitor='val_exp_rmspe', mode='min',
                                                            save_best_only=True))

    # Make Petastorm readers.
    with make_batch_reader('%s/train_df.parquet' % args.data_dir, num_epochs=None,
                           cur_shard=hvd.rank(), shard_count=hvd.size(),
                           hdfs_driver=PETASTORM_HDFS_DRIVER) as train_reader:
        with make_batch_reader('%s/val_df.parquet' % args.data_dir, num_epochs=None,
                               cur_shard=hvd.rank(), shard_count=hvd.size(),
                               hdfs_driver=PETASTORM_HDFS_DRIVER) as val_reader:
            # Convert readers to tf.data.Dataset.
            train_ds = make_petastorm_dataset(train_reader) \
                .apply(tf.data.experimental.unbatch()) \
                .shuffle(int(train_rows / hvd.size())) \
                .batch(args.batch_size) \
                .map(lambda x: (tuple(getattr(x, col) for col in all_cols), tf.log(x.Sales)))

            val_ds = make_petastorm_dataset(val_reader) \
                .apply(tf.data.experimental.unbatch()) \
                .batch(args.batch_size) \
                .map(lambda x: (tuple(getattr(x, col) for col in all_cols), tf.log(x.Sales)))

            history = model.fit(train_ds,
                                validation_data=val_ds,
                                steps_per_epoch=int(train_rows / args.batch_size / hvd.size()),
                                validation_steps=int(val_rows / args.batch_size / hvd.size()),
                                callbacks=callbacks,
                                verbose=verbose,
                                epochs=args.epochs)

    # Dataset API usage currently displays a wall of errors upon termination.
    # This global model registration ensures clean termination.
    # Tracked in https://github.com/tensorflow/tensorflow/issues/24570
    globals()['_DATASET_FINALIZATION_HACK'] = model

    if hvd.rank() == 0:
        with open(ckpt_file, 'rb') as f:
            return history.history, f.read()
# COMMAND ----------

import pyarrow.parquet as pq

underscore_files = [
    f for f in os.listdir(get_local_path(parquet_path)) if f.startswith("_")
]
pq.EXCLUDED_PARQUET_PATHS.update(underscore_files)

# COMMAND ----------

# We use make_batch_reader to load Parquet row groups into batches.
# HINT: Use the cur_shard and shard_count params to shard data in distributed training.
petastorm_dataset_url = "file://" + get_local_path(parquet_path)

with make_batch_reader(petastorm_dataset_url, num_epochs=100) as reader:
    dataset = make_petastorm_dataset(reader) \
        .map(lambda x: (tf.reshape(x.features, [-1, 28, 28, 1]), tf.one_hot(x.label, 10)))
    model = get_model()
    optimizer = keras.optimizers.Adadelta()
    model.compile(optimizer=optimizer,
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    model.fit(dataset, steps_per_epoch=10, epochs=10)

# COMMAND ----------

# Clean up the working directory.
dbutils.fs.rm(work_dir, recurse=True)

# COMMAND ----------
def decode_mask(tensor):
    codec = NdarrayCodec()
    mask_np_array = codec.decode(TrainSchema[1], tensor[1])
    return mask_np_array

# transform = TransformSpec(decode_image_and_mask)

with make_batch_reader('hdfs://node013.ib.cluster:8020/train/train_df.parquet', num_epochs=None,
                       cur_shard=hvd.rank(), shard_count=hvd.size(),
                       hdfs_driver='libhdfs') as train_reader:
    train_ds = make_petastorm_dataset(train_reader) \
        .apply(tf.data.experimental.unbatch()) \
        .map(lambda tensor: (tf.py_func(decode_image, [tensor], tf.uint8),
                             tf.py_func(decode_mask, [tensor], tf.uint8)))

    iterator = train_ds.make_one_shot_iterator()
    for x, y in iterator:
        print(x)

    # tensor = iterator.get_next()
    # with tf.Session() as sess:
    #     print("v1", tf.shape(tensor))
    #     sample = sess.run(tf.shape(tensor))
    #     print(sample)
    #     next_element = iterator.get_next()
    #     sess.run(iterator.initializer)
    #     while True:
    #         try:
    #             (features, labels) = sess.run(tensor)
def train_fn(model_bytes, batch_size, epochs, train_rows, val_rows, warmup_epochs=5,
             parquet_path=PARQUET_PATH, checkpoint_file_name=CHECKPOINT_FILE_NAME):
    '''
    Horovod training function

    Can be used standalone or passed into a Horovod Spark distributor.
    NOTE: This function should run without any bugs by itself.

    inputs:
        io.bytes model_bytes: A serialised compiled keras model containing the model, optimiser, etc.
        int batch_size: Batch size
        int epochs: # epochs
        int train_rows: # rows in training set
        int val_rows: # rows in val set
        int warmup_epochs: see callbacks section
        string parquet_path: path to training, validation and test parquet files
        string checkpoint_file_name: name of checkpoint file

    return:
        dict history: dictionary containing the history of loss and metrics
        io.bytes best_model_bytes: best model, serialised
    '''
    # Make sure pyarrow is referenced before anything else to avoid segfault due to conflict
    # with TensorFlow libraries. Use `pa` package reference to ensure it's loaded before
    # functions like `deserialize_model` which are implemented at the top level.
    # See https://jira.apache.org/jira/browse/ARROW-3346
    pa

    # Horovod: initialize Horovod inside the trainer.
    hvd.init()

    # Horovod: pin GPU to be used to process local rank (one GPU per process), if GPUs are available.
    physical_devices = tf.config.list_physical_devices('GPU')
    if physical_devices is not None:
        tf.config.set_visible_devices(physical_devices, 'GPU')
        for device in physical_devices:
            tf.config.experimental.set_memory_growth(device, True)

    # Horovod: restore from checkpoint, use hvd.load_model under the hood.
    model = deserialize_model(model_bytes, hvd.load_model)

    # Horovod: adjust learning rate based on number of processes.
    scaled_lr = K.get_value(model.optimizer.lr) * hvd.size()
    K.set_value(model.optimizer.lr, scaled_lr)

    # Horovod: print summary logs on the first worker.
    verbose = 2 if hvd.rank() == 0 else 0

    callbacks = [
        # Horovod: broadcast initial variable states from rank 0 to all other processes.
        # This is necessary to ensure consistent initialization of all workers when
        # training is started with random weights or restored from a checkpoint.
        hvd.callbacks.BroadcastGlobalVariablesCallback(root_rank=0),

        # Horovod: average metrics among workers at the end of every epoch.
        #
        # Note: This callback must be in the list before the ReduceLROnPlateau,
        # TensorBoard, or other metrics-based callbacks.
        hvd.callbacks.MetricAverageCallback(),

        # Horovod: using `lr = 1.0 * hvd.size()` from the very beginning leads to worse final
        # accuracy. Scale the learning rate `lr = 1.0` ---> `lr = 1.0 * hvd.size()` during
        # the first five epochs. See https://arxiv.org/abs/1706.02677 for details.
        hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=warmup_epochs, initial_lr=scaled_lr,
                                                 verbose=verbose),

        # Reduce LR if the metric is not improved for 10 epochs, and stop training
        # if it has not improved for 20 epochs.
        tf.keras.callbacks.ReduceLROnPlateau(monitor='val_mean_squared_error', patience=10, verbose=verbose),
        tf.keras.callbacks.EarlyStopping(monitor='val_mean_squared_error', mode='min', patience=20,
                                         verbose=verbose),
        tf.keras.callbacks.TerminateOnNaN()
    ]

    # Model checkpoint location.
    ckpt_dir = tempfile.mkdtemp()
    ckpt_file = os.path.join(ckpt_dir, checkpoint_file_name)
    atexit.register(lambda: shutil.rmtree(ckpt_dir))

    # Horovod: save checkpoints only on the first worker to prevent other workers from corrupting them.
    if hvd.rank() == 0:
        callbacks.append(
            tf.keras.callbacks.ModelCheckpoint(
                ckpt_file,
                monitor='val_mean_squared_error',
                mode='min',
                save_best_only=True))

    # Make Petastorm readers.
    with make_batch_reader(
            'file://' + os.path.join(os.getcwd(), parquet_path, 'train_df.parquet'),
            num_epochs=None,
            cur_shard=hvd.rank(),
            shard_count=hvd.size(),
            hdfs_driver=PETASTORM_HDFS_DRIVER) as train_reader:
        with make_batch_reader(
                'file://' + os.path.join(os.getcwd(), parquet_path, 'val_df.parquet'),
                num_epochs=None,
                cur_shard=hvd.rank(),
                shard_count=hvd.size(),
                hdfs_driver=PETASTORM_HDFS_DRIVER) as val_reader:

            # Convert readers to tf.data.Dataset.
            train_ds = make_petastorm_dataset(train_reader) \
                .apply(tf.data.experimental.unbatch()) \
                .shuffle(int(train_rows / hvd.size())) \
                .map(lambda x: encode_map_fn(x.new_movie_title, x.avg_rating)) \
                .padded_batch(batch_size)

            val_ds = make_petastorm_dataset(val_reader) \
                .apply(tf.data.experimental.unbatch()) \
                .map(lambda x: encode_map_fn(x.new_movie_title, x.avg_rating)) \
                .padded_batch(batch_size)

            history = model.fit(
                train_ds,
                validation_data=val_ds,
                steps_per_epoch=int(train_rows / batch_size / hvd.size()),
                validation_steps=int(val_rows / batch_size / hvd.size()),
                callbacks=callbacks,
                verbose=verbose,
                epochs=epochs)

    # Dataset API usage currently displays a wall of errors upon termination.
    # This global model registration ensures clean termination.
    # Tracked in https://github.com/tensorflow/tensorflow/issues/24570
    globals()['_DATASET_FINALIZATION_HACK'] = model

    if hvd.rank() == 0:
        with open(ckpt_file, 'rb') as f:
            return history.history, f.read()