Exemplo n.º 1
0
import os
import json

import tensorflow as tf
#os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
import mnist

# Number of examples handed to each individual worker per training step.
per_worker_batch_size = 64

# TF_CONFIG is parsed as JSON; the number of entries under its
# cluster -> worker key is taken as the worker count.
tf_config = json.loads(os.environ['TF_CONFIG'])
worker_entries = tf_config['cluster']['worker']
num_workers = len(worker_entries)

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

# The dataset is requested with the global batch size, i.e. the per-worker
# batch size scaled by the number of workers in the cluster.
global_batch_size = num_workers * per_worker_batch_size
multi_worker_dataset = mnist.mnist_dataset(global_batch_size)

# Model building/compiling need to be within `strategy.scope()`.
with strategy.scope():
    multi_worker_model = mnist.build_and_compile_cnn_model()

# Train for 3 epochs of 70 steps each.
multi_worker_model.fit(
    multi_worker_dataset,
    epochs=3,
    steps_per_epoch=70,
)
    cases where workers die or are otherwise unstable. You do this by preserving training state in the distributed
    file system of your choice, such that upon restart of the instance that previously failed or was preempted,
    the training state is recovered. Since all the workers are kept in sync in terms of training epochs and steps, 
    other workers would need to wait for the failed or preempted worker to restart to continue.[3]
      
    """

    # get the dataset
    dataset = mnist.mnist_dataset(batch_size)
    """
    
    This automatically shards the data based on the sharding policy provided
    by the tf.distribute strategy [3].
    
    """

    # Model building/compiling need to be within `strategy.scope()`.
    with strategy.scope():
        model = mnist.build_and_compile_cnn_model()

    # fit the model
    model.fit(dataset, epochs=args.epochs, steps_per_epoch=args.steps)
"""

References:
    
    [1]. https://www.codingame.com/playgrounds/349/introduction-to-mpi/introduction-to-collective-communications
    [2]. https://www.tensorflow.org/guide/distributed_training
    [3]. https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras
    
"""