import os
import json

import tensorflow as tf
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import mnist

per_worker_batch_size = 64
tf_config = json.loads(os.environ['TF_CONFIG'])
num_workers = len(tf_config['cluster']['worker'])

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

global_batch_size = per_worker_batch_size * num_workers
multi_worker_dataset = mnist.mnist_dataset(global_batch_size)

with strategy.scope():
    # Model building/compiling need to be within `strategy.scope()`.
    multi_worker_model = mnist.build_and_compile_cnn_model()

multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)
cases where workers die or are otherwise unstable. You do this by preserving
the training state in the distributed file system of your choice, such that
upon restart of the instance that previously failed or was preempted, the
training state is recovered. Since all the workers are kept in sync in terms
of training epochs and steps, the other workers would have to wait for the
failed or preempted worker to restart before continuing.[3] A minimal sketch
of wiring this up with Keras's `BackupAndRestore` callback is given after the
references at the end of this file.
"""

# Get the dataset.
dataset = mnist.mnist_dataset(batch_size)

"""
The dataset is automatically sharded across the workers according to the
sharding policy provided by the tf.distribute strategy [3]. A sketch of
setting the policy explicitly also follows the references below.
"""

# Model building/compiling need to be within `strategy.scope()`.
with strategy.scope():
    model = mnist.build_and_compile_cnn_model()

# Fit the model.
model.fit(dataset, epochs=args.epochs, steps_per_epoch=args.steps)

"""
References:
[1] https://www.codingame.com/playgrounds/349/introduction-to-mpi/introduction-to-collective-communications
[2] https://www.tensorflow.org/guide/distributed_training
[3] https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras
"""
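"""
Illustrative sketch (not part of the original script): the automatic sharding
mentioned above defaults to AutoShardPolicy.AUTO, which shards by file when
the input comes from files and otherwise falls back to sharding by data. The
policy can be set explicitly through `tf.data.Options` before the dataset is
passed to `fit`.
"""
options = tf.data.Options()
# DATA sharding: every worker reads the full dataset but only processes its
# own subset of the elements.
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.DATA)
sharded_dataset = dataset.with_options(options)
# `model.fit(sharded_dataset, ...)` would then be used in place of the call above.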
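"""
Illustrative sketch (not part of the original script): the fault tolerance
described at the top of this file is obtained in the tutorial [3] by attaching
a BackupAndRestore callback to `fit`. Depending on the TensorFlow version the
callback lives at `tf.keras.callbacks.BackupAndRestore` or
`tf.keras.callbacks.experimental.BackupAndRestore`; the backup directory below
is a placeholder and should point at storage visible to every worker.
"""
backup_callback = tf.keras.callbacks.BackupAndRestore(backup_dir='/tmp/backup')
# With the callback attached, training state is checkpointed at the end of
# every epoch, and a restarted worker resumes from the last completed epoch
# instead of from scratch. In practice the callback would be passed to the
# `fit` call above, e.g.:
# model.fit(dataset, epochs=args.epochs, steps_per_epoch=args.steps,
#           callbacks=[backup_callback])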
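"""
Illustrative sketch (values are placeholders, not from the original script):
the 'TF_CONFIG' environment variable used above is a JSON string that every
worker receives before launch. The 'cluster' section is identical on all
workers, while the 'task' section identifies the particular worker [3].
"""
example_tf_config = {
    'cluster': {
        'worker': ['localhost:12345', 'localhost:23456']
    },
    'task': {'type': 'worker', 'index': 0}  # the second worker would use index 1
}
# Each worker would export the serialized form before running the script, e.g.
# os.environ['TF_CONFIG'] = json.dumps(example_tf_config)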