Exemplo n.º 1
0
def _get_assigned_gpu_or_default(default):
    """Return the GPU index this task should pin to.

    Asks ``horovod.spark.task.get_available_devices`` for the devices
    available to the task. When any are reported, the first one is
    returned as an ``int``; otherwise ``default`` is returned unchanged.

    Args:
        default: Value to return when no devices are reported (the caller's
            default GPU index, e.g. the local rank).

    Returns:
        ``int(devices[0])`` when devices are reported, else ``default``.
    """
    from horovod.spark.task import get_available_devices

    assigned = get_available_devices()
    if not assigned:
        # Nothing reported: pin to the default GPU index (local rank).
        return default

    # GPU-aware scheduling is available: pin to the assigned GPU index.
    return int(assigned[0])
Exemplo n.º 2
0
    def train_fn(model_bytes):
        """Train the serialized Keras model on this Horovod worker.

        Initializes Horovod, pins the process to a GPU, deserializes the model
        from ``model_bytes``, scales its learning rate by ``hvd.size()`` and
        fits it on this worker's shard of the Petastorm train/validation data.

        Args:
            model_bytes: Serialized model, in the form accepted by
                ``deserialize_model``.

        Returns:
            On rank 0, a tuple ``(history.history, checkpoint_bytes)`` where
            ``checkpoint_bytes`` is the raw content of the checkpoint file
            written by the ``ModelCheckpoint`` callback. On every other rank
            the function falls off the end and returns ``None``.
        """
        # Make sure pyarrow is referenced before anything else to avoid segfault due to conflict
        # with TensorFlow libraries.  Use `pa` package reference to ensure it's loaded before
        # functions like `deserialize_model` which are implemented at the top level.
        # See https://jira.apache.org/jira/browse/ARROW-3346
        pa

        import atexit
        import horovod.tensorflow.keras as hvd
        from horovod.spark.task import get_available_devices
        import os
        from petastorm import make_batch_reader
        from petastorm.tf_utils import make_petastorm_dataset
        import tempfile
        import tensorflow as tf
        import tensorflow.keras.backend as K
        import shutil

        # Horovod: initialize Horovod inside the trainer.
        hvd.init()

        # Horovod: pin GPU to be used to process local rank (one GPU per process), if GPUs are available.
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        # The device list may be empty (e.g. when GPU-aware scheduling is not
        # available); indexing it unconditionally would raise IndexError.
        # Fall back to the local rank in that case.
        available_devices = get_available_devices()
        if available_devices:
            config.gpu_options.visible_device_list = available_devices[0]
        else:
            config.gpu_options.visible_device_list = str(hvd.local_rank())
        K.set_session(tf.Session(config=config))

        # Horovod: restore from checkpoint, use hvd.load_model under the hood.
        model = deserialize_model(model_bytes, hvd.load_model)

        # Horovod: adjust learning rate based on number of processes.
        scaled_lr = K.get_value(model.optimizer.lr) * hvd.size()
        K.set_value(model.optimizer.lr, scaled_lr)

        # Horovod: print summary logs on the first worker.
        verbose = 2 if hvd.rank() == 0 else 0

        callbacks = [
            # Horovod: broadcast initial variable states from rank 0 to all other processes.
            # This is necessary to ensure consistent initialization of all workers when
            # training is started with random weights or restored from a checkpoint.
            hvd.callbacks.BroadcastGlobalVariablesCallback(root_rank=0),

            # Horovod: average metrics among workers at the end of every epoch.
            #
            # Note: This callback must be in the list before the ReduceLROnPlateau,
            # TensorBoard, or other metrics-based callbacks.
            hvd.callbacks.MetricAverageCallback(),

            # Horovod: using `lr = 1.0 * hvd.size()` from the very beginning leads to worse final
            # accuracy. Scale the learning rate `lr = 1.0` ---> `lr = 1.0 * hvd.size()` during
            # the first five epochs. See https://arxiv.org/abs/1706.02677 for details.
            hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=5, initial_lr=scaled_lr, verbose=verbose),

            # Reduce LR if the metric is not improved for 10 epochs, and stop training
            # if it has not improved for 20 epochs.
            tf.keras.callbacks.ReduceLROnPlateau(monitor='val_exp_rmspe', patience=10, verbose=verbose),
            tf.keras.callbacks.EarlyStopping(monitor='val_exp_rmspe', mode='min', patience=20, verbose=verbose),
            tf.keras.callbacks.TerminateOnNaN()
        ]

        # Model checkpoint location. The temporary directory is removed at
        # interpreter exit.
        ckpt_dir = tempfile.mkdtemp()
        ckpt_file = os.path.join(ckpt_dir, 'checkpoint.h5')
        atexit.register(lambda: shutil.rmtree(ckpt_dir))

        # Horovod: save checkpoints only on the first worker to prevent other workers from corrupting them.
        if hvd.rank() == 0:
            callbacks.append(tf.keras.callbacks.ModelCheckpoint(ckpt_file, monitor='val_exp_rmspe', mode='min',
                                                                save_best_only=True))

        # Make Petastorm readers.
        with make_batch_reader('%s/train_df.parquet' % args.data_dir, num_epochs=None,
                               cur_shard=hvd.rank(), shard_count=hvd.size(),
                               hdfs_driver=PETASTORM_HDFS_DRIVER) as train_reader:
            with make_batch_reader('%s/val_df.parquet' % args.data_dir, num_epochs=None,
                                   cur_shard=hvd.rank(), shard_count=hvd.size(),
                                   hdfs_driver=PETASTORM_HDFS_DRIVER) as val_reader:
                # Convert readers to tf.data.Dataset.
                train_ds = make_petastorm_dataset(train_reader) \
                    .apply(tf.data.experimental.unbatch()) \
                    .shuffle(int(train_rows / hvd.size())) \
                    .batch(args.batch_size) \
                    .map(lambda x: (tuple(getattr(x, col) for col in all_cols), tf.log(x.Sales)))

                val_ds = make_petastorm_dataset(val_reader) \
                    .apply(tf.data.experimental.unbatch()) \
                    .batch(args.batch_size) \
                    .map(lambda x: (tuple(getattr(x, col) for col in all_cols), tf.log(x.Sales)))

                history = model.fit(train_ds,
                                    validation_data=val_ds,
                                    steps_per_epoch=int(train_rows / args.batch_size / hvd.size()),
                                    validation_steps=int(val_rows / args.batch_size / hvd.size()),
                                    callbacks=callbacks,
                                    verbose=verbose,
                                    epochs=args.epochs)

        # Dataset API usage currently displays a wall of errors upon termination.
        # This global model registration ensures clean termination.
        # Tracked in https://github.com/tensorflow/tensorflow/issues/24570
        globals()['_DATASET_FINALIZATION_HACK'] = model

        # Only rank 0 registered the ModelCheckpoint callback, so only rank 0
        # has a checkpoint file to read back and return.
        if hvd.rank() == 0:
            with open(ckpt_file, 'rb') as f:
                return history.history, f.read()
Exemplo n.º 3
0
 def fn():
     """Initialize Horovod and report this process's devices and local rank.

     Returns:
         A tuple ``(devices, local_rank)``: the result of
         ``get_available_devices()`` followed by ``hvd.local_rank()``.
     """
     hvd.init()
     # Tuple elements evaluate left to right, so the device query still
     # happens before the local-rank lookup.
     return get_available_devices(), hvd.local_rank()