def _get_assigned_gpu_or_default(default):
    from horovod.spark.task import get_available_devices
    available_devices = get_available_devices()
    if available_devices:
        # if GPU-aware scheduling is available, pin to the assigned GPU index
        return int(available_devices[0])
    else:
        # pin to default GPU index (local rank)
        return default
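# Usage sketch (an assumption, not part of the original script): inside a
# Horovod-on-Spark task the helper above resolves which GPU to pin. With
# GPU-aware scheduling, `get_available_devices()` returns the device indices
# assigned to this task; otherwise the list is empty and the caller's default
# (typically the Horovod local rank) is used:
#
#   gpu_index = _get_assigned_gpu_or_default(default=hvd.local_rank())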
def train_fn(model_bytes):
    # Make sure pyarrow is referenced before anything else to avoid segfault due to conflict
    # with TensorFlow libraries. Use `pa` package reference to ensure it's loaded before
    # functions like `deserialize_model` which are implemented at the top level.
    # See https://jira.apache.org/jira/browse/ARROW-3346
    pa

    import atexit
    import horovod.tensorflow.keras as hvd
    import os
    import shutil
    import tempfile
    import tensorflow as tf
    import tensorflow.keras.backend as K
    from petastorm import make_batch_reader
    from petastorm.tf_utils import make_petastorm_dataset

    # Horovod: initialize Horovod inside the trainer.
    hvd.init()

    # Horovod: pin GPU to be used to process local rank (one GPU per process), if GPUs are available.
    # Falls back to the local rank when GPU-aware scheduling is not available.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(_get_assigned_gpu_or_default(default=hvd.local_rank()))
    K.set_session(tf.Session(config=config))

    # Horovod: restore from checkpoint, use hvd.load_model under the hood.
    model = deserialize_model(model_bytes, hvd.load_model)

    # Horovod: adjust learning rate based on number of processes.
    scaled_lr = K.get_value(model.optimizer.lr) * hvd.size()
    K.set_value(model.optimizer.lr, scaled_lr)

    # Horovod: print summary logs on the first worker.
    verbose = 2 if hvd.rank() == 0 else 0

    callbacks = [
        # Horovod: broadcast initial variable states from rank 0 to all other processes.
        # This is necessary to ensure consistent initialization of all workers when
        # training is started with random weights or restored from a checkpoint.
        hvd.callbacks.BroadcastGlobalVariablesCallback(root_rank=0),

        # Horovod: average metrics among workers at the end of every epoch.
        #
        # Note: This callback must be in the list before the ReduceLROnPlateau,
        # TensorBoard, or other metrics-based callbacks.
        hvd.callbacks.MetricAverageCallback(),

        # Horovod: using `lr = 1.0 * hvd.size()` from the very beginning leads to worse final
        # accuracy. Scale the learning rate `lr = 1.0` ---> `lr = 1.0 * hvd.size()` during
        # the first five epochs. See https://arxiv.org/abs/1706.02677 for details.
        hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=5, initial_lr=scaled_lr, verbose=verbose),

        # Reduce LR if the metric is not improved for 10 epochs, and stop training
        # if it has not improved for 20 epochs.
        tf.keras.callbacks.ReduceLROnPlateau(monitor='val_exp_rmspe', patience=10, verbose=verbose),
        tf.keras.callbacks.EarlyStopping(monitor='val_exp_rmspe', mode='min', patience=20, verbose=verbose),
        tf.keras.callbacks.TerminateOnNaN()
    ]

    # Model checkpoint location.
    ckpt_dir = tempfile.mkdtemp()
    ckpt_file = os.path.join(ckpt_dir, 'checkpoint.h5')
    atexit.register(lambda: shutil.rmtree(ckpt_dir))

    # Horovod: save checkpoints only on the first worker to prevent other workers from corrupting them.
    if hvd.rank() == 0:
        callbacks.append(tf.keras.callbacks.ModelCheckpoint(ckpt_file, monitor='val_exp_rmspe', mode='min',
                                                            save_best_only=True))

    # Make Petastorm readers.
    with make_batch_reader('%s/train_df.parquet' % args.data_dir, num_epochs=None,
                           cur_shard=hvd.rank(), shard_count=hvd.size(),
                           hdfs_driver=PETASTORM_HDFS_DRIVER) as train_reader:
        with make_batch_reader('%s/val_df.parquet' % args.data_dir, num_epochs=None,
                               cur_shard=hvd.rank(), shard_count=hvd.size(),
                               hdfs_driver=PETASTORM_HDFS_DRIVER) as val_reader:
            # Convert readers to tf.data.Dataset.
            train_ds = make_petastorm_dataset(train_reader) \
                .apply(tf.data.experimental.unbatch()) \
                .shuffle(int(train_rows / hvd.size())) \
                .batch(args.batch_size) \
                .map(lambda x: (tuple(getattr(x, col) for col in all_cols), tf.log(x.Sales)))

            val_ds = make_petastorm_dataset(val_reader) \
                .apply(tf.data.experimental.unbatch()) \
                .batch(args.batch_size) \
                .map(lambda x: (tuple(getattr(x, col) for col in all_cols), tf.log(x.Sales)))

            history = model.fit(train_ds,
                                validation_data=val_ds,
                                steps_per_epoch=int(train_rows / args.batch_size / hvd.size()),
                                validation_steps=int(val_rows / args.batch_size / hvd.size()),
                                callbacks=callbacks,
                                verbose=verbose,
                                epochs=args.epochs)

    # Dataset API usage currently displays a wall of errors upon termination.
    # This global model registration ensures clean termination.
    # Tracked in https://github.com/tensorflow/tensorflow/issues/24570
    globals()['_DATASET_FINALIZATION_HACK'] = model

    if hvd.rank() == 0:
        with open(ckpt_file, 'rb') as f:
            return history.history, f.read()
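# Invocation sketch (an assumption about the surrounding script, mirroring the
# standard Horovod-on-Spark pattern): `train_fn` is launched on Spark executors
# with `horovod.spark.run`, which returns a list with one entry per Horovod
# rank. Since only rank 0 returns a value above, entry 0 carries the training
# history and the checkpoint bytes. `model_bytes` and `args.num_proc` are
# assumed to be defined elsewhere in the script.
#
#   import horovod.spark
#   results = horovod.spark.run(train_fn, args=(model_bytes,),
#                               num_proc=args.num_proc, verbose=2)
#   best_model_history, best_model_bytes = results[0]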
def fn():
    # Local imports so the function stays self-contained when pickled to Spark workers.
    import horovod.tensorflow.keras as hvd
    from horovod.spark.task import get_available_devices

    # Report the devices assigned by GPU-aware scheduling along with this
    # process's Horovod local rank.
    hvd.init()
    devices = get_available_devices()
    return devices, hvd.local_rank()
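# Probe sketch (hypothetical usage): dispatching `fn` with `horovod.spark.run`
# shows which devices GPU-aware scheduling assigned to each Horovod rank.
# Assumes a live SparkContext; `horovod.spark.run` returns one result per
# Horovod process.
#
#   import horovod.spark
#   for devices, local_rank in horovod.spark.run(fn):
#       print('local rank %d -> devices %s' % (local_rank, devices))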