def main(_):

    # Initialize Horovod (TODO: remove the Horovod dependency for graph freezing)
    hvd.init()

    if not FLAGS.output_file:
        raise ValueError(
            'You must supply the path to save to with --output_file')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default() as graph:
        if FLAGS.input_format == 'NCHW':
            input_shape = [
                FLAGS.batch_size, 3, FLAGS.image_size, FLAGS.image_size
            ]
        else:
            input_shape = [
                FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, 3
            ]
        input_images = tf.placeholder(name='input',
                                      dtype=tf.float32,
                                      shape=input_shape)

        resnet_config = resnet.model_architectures[FLAGS.model_name]
        network = resnet.ResnetModel(FLAGS.model_name, FLAGS.num_classes,
                                     resnet_config['layers'],
                                     resnet_config['widths'],
                                     resnet_config['expansions'],
                                     FLAGS.compute_format, FLAGS.input_format)
        probs, logits = network.build_model(
            input_images,
            training=False,
            reuse=False,
            use_final_conv=FLAGS.use_final_conv)

        if FLAGS.quantize:
            tf.contrib.quantize.experimental_create_eval_graph(
                symmetric=FLAGS.symmetric, use_qdq=FLAGS.use_qdq)

        # Define the saver and restore the checkpoint
        saver = tf.train.Saver()
        with tf.Session() as sess:
            if FLAGS.checkpoint:
                saver.restore(sess, FLAGS.checkpoint)
            else:
                sess.run(tf.global_variables_initializer())
            graph_def = graph.as_graph_def()
            frozen_graph_def = tf.graph_util.convert_variables_to_constants(
                sess, graph_def, [probs.op.name])

        # Write out the frozen graph
        tf.io.write_graph(frozen_graph_def,
                          os.path.dirname(FLAGS.output_file),
                          os.path.basename(FLAGS.output_file),
                          as_text=FLAGS.write_text_graphdef)
Example #2
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import math
import multiprocessing
import os

import six

import tensorflow as tf
import smdistributed.dataparallel.tensorflow as hvd
hvd.init()

from mask_rcnn.utils.logging_formatter import logging

from mask_rcnn.utils.distributed_utils import MPI_is_distributed
from mask_rcnn.utils.distributed_utils import MPI_local_rank
from mask_rcnn.utils.distributed_utils import MPI_rank

from mask_rcnn.hooks.logging_hook import AutoLoggingHook

from mask_rcnn.utils.lazy_imports import LazyImport
from tensorflow.core.protobuf import rewriter_config_pb2

from mask_rcnn import evaluation
from mask_rcnn.hyperparameters import params_io
from mask_rcnn.hooks import CheckpointSaverHook
    parser.add_argument("--model_dir", type=str, default=os.environ["SM_MODEL_DIR"])
    parser.add_argument("--n_gpus", type=str, default=os.environ["SM_NUM_GPUS"])

    args, _ = parser.parse_known_args()

    # Set up logging
    logger = logging.getLogger(__name__)

    logging.basicConfig(
        level=logging.getLevelName("INFO"),
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    if SDP_ENABLED:
        sdp.init()

        gpus = tf.config.experimental.list_physical_devices("GPU")
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        if gpus:
            tf.config.experimental.set_visible_devices(gpus[sdp.local_rank()], "GPU")

    # Load model and tokenizer
    model = TFAutoModelForSequenceClassification.from_pretrained(args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # get datasets
    tf_train_dataset, tf_test_dataset = get_datasets()

    # define optimizer and loss
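The fragment breaks off at this comment; a typical continuation (an assumption, not the original code) would scale the learning rate by the worker count and compile the model:

    # Hypothetical continuation: args.learning_rate is assumed to exist
    lr = args.learning_rate * sdp.size() if SDP_ENABLED else args.learning_rate
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])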
Example #4
def main(args):
    # Hyper-parameters
    epochs = args.epochs
    lr = args.learning_rate
    batch_size = args.batch_size
    momentum = args.momentum
    weight_decay = args.weight_decay
    optimizer = args.optimizer
    model_type = args.model_type

    # SageMaker options
    training_dir = args.train
    validation_dir = args.validation
    eval_dir = args.eval

    # Change: Initialize SMDataParallel and get the size of the cluster
    smdp.init()
    size = smdp.size()

    # Change: Pin GPU to local process (one GPU per process)
    gpus = tf.config.experimental.list_physical_devices('GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        # SMDataParallel: Pin GPUs to a single SMDataParallel process [use SMDataParallel local_rank() API]
        tf.config.experimental.set_visible_devices(gpus[smdp.local_rank()],
                                                   'GPU')

    # Get dataset
    train_dataset = get_dataset(training_dir + '/train.tfrecords', batch_size)
    train_dataset = train_dataset.take(NUM_TRAIN_IMAGES // size).shuffle(10000)

    val_dataset = get_dataset(validation_dir + '/validation.tfrecords',
                              batch_size)
    eval_dataset = get_dataset(eval_dir + '/eval.tfrecords', batch_size)

    # Load model
    model = get_model(model_type)

    # Optimizer
    if optimizer.lower() == 'adam':
        opt = Adam(learning_rate=lr * size, decay=weight_decay)
    elif optimizer.lower() == 'rmsprop':
        opt = RMSprop(learning_rate=lr * size, decay=weight_decay)
    else:
        opt = SGD(learning_rate=lr * size, decay=weight_decay, momentum=momentum)

    # Loss function
    loss = tf.keras.losses.CategoricalCrossentropy()

    # Metrics to track
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.CategoricalAccuracy(
        name='train_accuracy')

    val_loss = tf.keras.metrics.Mean(name='val_loss')
    val_accuracy = tf.keras.metrics.CategoricalAccuracy(name='val_accuracy')

    test_loss = tf.keras.metrics.Mean(name='test_loss')
    test_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy')

    # Training step
    @tf.function
    def training_step(images, labels, first_batch):
        with tf.GradientTape() as tape:
            train_pred = model(images, training=True)
            loss_value = loss(labels, train_pred)
        # Change: Wrap tf.GradientTape with SMDataParallel's DistributedGradientTape
        tape = smdp.DistributedGradientTape(tape)

        grads = tape.gradient(loss_value, model.trainable_variables)
        opt.apply_gradients(zip(grads, model.trainable_variables))

        if first_batch:
            # Change: Broadcast model and optimizer variables
            smdp.broadcast_variables(model.variables, root_rank=0)
            smdp.broadcast_variables(opt.variables(), root_rank=0)

        # Change: all_reduce call
        train_loss_value = smdp.oob_allreduce(
            loss_value)  # Average the loss across workers

        train_loss(train_loss_value)
        train_accuracy(labels, train_pred)
        return

    # Test step
    @tf.function
    def test_step(images, labels):
        val_pred = model(images, training=False)
        val_loss_value = loss(labels, val_pred)

        val_loss(val_loss_value)
        val_accuracy(labels, val_pred)
        return

    if smdp.rank() == 0:
        tb_log_dir = '/opt/ml/output/tensorboard/'
        train_summary_writer = tf.summary.create_file_writer(tb_log_dir)
        test_summary_writer = tf.summary.create_file_writer(tb_log_dir)

    # Training loop
    for epoch in range(epochs):
        train_loss.reset_states()
        train_accuracy.reset_states()
        val_loss.reset_states()
        val_accuracy.reset_states()

        epoch_start_time = time.time()
        for batch, (images, labels) in enumerate(train_dataset):
            training_step(images, labels, batch == 0)
        epoch_time = time.time() - epoch_start_time

        for images, labels in val_dataset:
            test_step(images, labels)

        if smdp.rank() == 0:
            with train_summary_writer.as_default():
                tf.summary.scalar('train_loss',
                                  train_loss.result(),
                                  step=epoch)
                tf.summary.scalar('train_accuracy',
                                  train_accuracy.result(),
                                  step=epoch)

            with test_summary_writer.as_default():
                tf.summary.scalar('val_loss', val_loss.result(), step=epoch)
                tf.summary.scalar('val_accuracy',
                                  val_accuracy.result(),
                                  step=epoch)

            print(
                f'Epoch: {epoch + 1}, '
                f'Epoch duration: {epoch_time:.2f} sec, '
                f'Training loss: {train_loss.result()}, '
                f'Training accuracy: {train_accuracy.result() * 100}, '
                f'Validation loss: {val_loss.result()}, '
                f'Validation accuracy: {val_accuracy.result() * 100}')

    for images, labels in eval_dataset:
        test_pred = model(images, training=False)
        test_loss_value = loss(labels, test_pred)

        test_loss(test_loss_value)
        test_accuracy(labels, test_pred)

    print('====== Test Results ======')
    print(f'Test loss: {test_loss.result()}, '
          f'Test accuracy: {test_accuracy.result() * 100}')
    print('====== End of training ======')

    # Change: Save checkpoints only from master node.
    if smdp.rank() == 0:
        model.save(os.path.join(os.environ["SM_MODEL_DIR"], '1'))
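get_dataset is not shown in this fragment; a minimal sketch of the TFRecord pipeline it might implement (the feature names, the 32x32x3 image shape, and the 10-class one-hot labels are assumptions):

def get_dataset(path, batch_size):
    feature_spec = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }

    def parse(record):
        parsed = tf.io.parse_single_example(record, feature_spec)
        image = tf.io.decode_raw(parsed['image'], tf.uint8)
        image = tf.cast(tf.reshape(image, [32, 32, 3]), tf.float32) / 255.0
        # One-hot labels to match the CategoricalCrossentropy loss above
        label = tf.one_hot(parsed['label'], depth=10)
        return image, label

    return (tf.data.TFRecordDataset(path)
            .map(parse, num_parallel_calls=tf.data.experimental.AUTOTUNE)
            .batch(batch_size)
            .prefetch(tf.data.experimental.AUTOTUNE))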
Example #5
# Copyright 2018 Uber Technologies, Inc. All Rights Reserved.
# Modifications Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and limitations under the License.

import tensorflow as tf
tf.random.set_seed(42)
import smdistributed.dataparallel.tensorflow as dist

dist.init()

gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    tf.config.experimental.set_visible_devices(gpus[dist.local_rank()], 'GPU')

(mnist_images, mnist_labels), _ = \
    tf.keras.datasets.mnist.load_data(path='mnist-%d.npz' % dist.rank())

dataset = tf.data.Dataset.from_tensor_slices(
    (tf.cast(mnist_images[..., tf.newaxis] / 255.0,
             tf.float32), tf.cast(mnist_labels, tf.int64)))
dataset = dataset.repeat().shuffle(10000).batch(128)
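# Note: every process iterates the full dataset here; to give each worker a
# distinct slice, one could add dataset.shard(dist.size(), dist.rank()) before
# shuffling (an option, not part of the original example).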

mnist_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, [3, 3], activation='relu'),
Example #6
from common.datasets import get_dataset_from_tfrecords
from common.models import create_model
from common.optimizers import get_adamw_optimizer, get_lamb_optimizer
from common.utils import (
    TqdmLoggingHandler,
    create_tokenizer,
    gather_indexes,
    is_wandb_available,
    rewrap_tf_function,
)

# See https://github.com/huggingface/transformers/issues/3782; this import must come last
import smdistributed.dataparallel.tensorflow as smddp  # isort:skip
import smddpcommon as hc

smddp.init()

if is_wandb_available():
    import wandb

# Exclude Pack operator in XLA for 2.4
if not _PRE_TF_2_4_0:
    os.environ['TF_XLA_FLAGS'] = "--tf_xla_auto_jit=1 --tf_xla_ops_to_cluster=Add,AddN,AddV2,All,ArgMax,AssignAddVariableOp,AssignVariableOp,BatchMatMulV2,BiasAdd,BiasAddGrad,Cast,ConcatV2,Const,Equal,Erf,Exp,ExpandDims,GatherV2,Greater,GreaterEqual,Identity,IdentityN,If,IsFinite,L2Loss,LessEqual,MatMul,Maximum,Mean,Minimum,Mul,Neg,NoOp,PartitionedCall,Pow,RandomUniform,ReadVariableOp,RealDiv,Reciprocal,Reshape,ResourceGather,Rsqrt,RsqrtGrad,SelectV2,Softmax,SparseSoftmaxCrossEntropyWithLogits,Sqrt,Square,SquaredDifference,Squeeze,StatelessIf,StridedSlice,StridedSliceGrad,Sub,Sum,Tanh,TanhGrad,Tile,Transpose,UnsortedSegmentSum,VariableShape"
    tf.keras.mixed_precision.set_global_policy('mixed_float16')

logger = logging.getLogger(__name__)


def mlm_loss_fn(
    prediction_logits: "[batch, max_seq_len (512), vocab_size]",
    label_positions: "[batch, num_masks (20)]",
Example #7
def main(_):
  # Lazy compilation causes memory fragmentation for BERT, leading to OOM
  os.environ["TF_XLA_FLAGS"] = "--tf_xla_enable_lazy_compilation=false"

  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  dllogging = utils.dllogger_class.dllogger_class(FLAGS.dllog_path)

  if not FLAGS.do_train and not FLAGS.do_eval:
    raise ValueError("At least one of `do_train` or `do_eval` must be True.")

  # Set seed to reduce randomness
  random.seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  tf.set_random_seed(FLAGS.seed)

  if FLAGS.herring:
    import smdistributed.dataparallel.tensorflow as hvd
    hvd.init()

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  tf.io.gfile.makedirs(FLAGS.output_dir)

  input_files = []
  for input_file_dir in FLAGS.input_files_dir.split(","):
    input_files.extend(tf.io.gfile.glob(os.path.join(input_file_dir, "*")))

  if FLAGS.herring and len(input_files) < hvd.size():
      raise ValueError("Input files must be sharded so that each worker receives at least one file")
  if FLAGS.amp and FLAGS.manual_fp16:
      raise ValueError("AMP and manual mixed-precision training cannot both be enabled")

  is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
  config = tf.compat.v1.ConfigProto()
  if FLAGS.herring:
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    if hvd.rank() == 0:
      tf.compat.v1.logging.info("***** Configuaration *****")
      for key in FLAGS.__flags.keys():
          tf.compat.v1.logging.info('  {}: {}'.format(key, getattr(FLAGS, key)))
      tf.compat.v1.logging.info("**************************")

#    config.gpu_options.per_process_gpu_memory_fraction = 0.7
  if FLAGS.use_xla: 
      config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
      config.graph_options.rewrite_options.memory_optimization = rewriter_config_pb2.RewriterConfig.NO_MEM_OPT
      if FLAGS.amp:
        tf.enable_resource_variables()

  run_config = tf.estimator.RunConfig(
      tf_random_seed=(FLAGS.seed if not FLAGS.herring else (FLAGS.seed + hvd.rank())),
      model_dir=FLAGS.output_dir,
      session_config=config,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps if not FLAGS.herring or hvd.rank() == 0 else None,
      save_summary_steps=FLAGS.save_checkpoints_steps if not FLAGS.herring or hvd.rank() == 0 else None,
      # This variable controls how often estimator reports examples/sec.
      # Default value is every 100 steps.
      # When --report_loss is True, we set to very large value to prevent
      # default info reporting from estimator.
      # Ideally we should set it to None, but that does not work.
      log_step_count_steps=10000 if FLAGS.report_loss else 100)

  model_fn = model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate if not FLAGS.herring else FLAGS.learning_rate*hvd.size(),
      num_train_steps=FLAGS.num_train_steps,
      num_warmup_steps=FLAGS.num_warmup_steps,
      use_one_hot_embeddings=False,
      hvd=None if not FLAGS.herring else hvd)

  estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      config=run_config)

  if FLAGS.do_train:

    training_hooks = []
    if FLAGS.herring and hvd.size() > 1:
      training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))
    if (not FLAGS.herring or hvd.rank() == 0):
      global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps
      if FLAGS.herring:
        global_batch_size *= hvd.size()
      training_hooks.append(_LogSessionRunHook(global_batch_size, FLAGS.num_accumulation_steps, dllogging, FLAGS.display_loss_steps, FLAGS.save_checkpoints_steps, FLAGS.report_loss))

    tf.compat.v1.logging.info("***** Running training *****")
    tf.compat.v1.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    train_input_fn = input_fn_builder(
        input_files=input_files,
        batch_size=FLAGS.train_batch_size,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        is_training=True,
        hvd=None if not FLAGS.herring else hvd)

    train_start_time = time.time()
    estimator.train(input_fn=train_input_fn, hooks=training_hooks, max_steps=FLAGS.num_train_steps)
    train_time_elapsed = time.time() - train_start_time

    if (not FLAGS.herring or hvd.rank() == 0):
        train_time_wo_overhead = training_hooks[-1].total_time
        avg_sentences_per_second = FLAGS.num_train_steps * global_batch_size * 1.0 / train_time_elapsed
        ss_sentences_per_second = (FLAGS.num_train_steps - training_hooks[-1].skipped) * global_batch_size * 1.0 / train_time_wo_overhead

        tf.compat.v1.logging.info("-----------------------------")
        tf.compat.v1.logging.info("Total Training Time = %0.2f for Sentences = %d", train_time_elapsed,
                        FLAGS.num_train_steps * global_batch_size)
        tf.compat.v1.logging.info("Total Training Time W/O Overhead = %0.2f for Sentences = %d", train_time_wo_overhead,
                        (FLAGS.num_train_steps - training_hooks[-1].skipped) * global_batch_size)
        tf.compat.v1.logging.info("Training Throughput Average (sentences/sec) with overhead = %0.2f", avg_sentences_per_second)
        tf.compat.v1.logging.info("Training Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
        dllogging.logger.log(step=(), data={"throughput_train": ss_sentences_per_second}, verbosity=Verbosity.DEFAULT)
        tf.compat.v1.logging.info("-----------------------------")

  if FLAGS.do_eval and (not FLAGS.herring or hvd.rank() == 0):
    tf.compat.v1.logging.info("***** Running evaluation *****")
    tf.compat.v1.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

    eval_files = []
    for eval_file_dir in FLAGS.eval_files_dir.split(","):
        eval_files.extend(tf.io.gfile.glob(os.path.join(eval_file_dir, "*")))

    eval_input_fn = input_fn_builder(
        input_files=eval_files,
        batch_size=FLAGS.eval_batch_size,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        is_training=False,
        hvd=None if not FLAGS.herring else hvd)

    eval_hooks = [LogEvalRunHook(FLAGS.eval_batch_size)]
    eval_start_time = time.time()
    result = estimator.evaluate(
        input_fn=eval_input_fn, steps=FLAGS.max_eval_steps, hooks=eval_hooks)

    eval_time_elapsed = time.time() - eval_start_time
    time_list = eval_hooks[-1].time_list
    time_list.sort()
    # Removing outliers (init/warmup) in throughput computation.
    eval_time_wo_overhead = sum(time_list[:int(len(time_list) * 0.99)])
    num_sentences = (int(len(time_list) * 0.99)) * FLAGS.eval_batch_size

    ss_sentences_per_second = num_sentences * 1.0 / eval_time_wo_overhead

    tf.compat.v1.logging.info("-----------------------------")
    tf.compat.v1.logging.info("Total Inference Time = %0.2f for Sentences = %d", eval_time_elapsed,
                    eval_hooks[-1].count * FLAGS.eval_batch_size)
    tf.compat.v1.logging.info("Total Inference Time W/O Overhead = %0.2f for Sentences = %d", eval_time_wo_overhead,
                    num_sentences)
    tf.compat.v1.logging.info("Summary Inference Statistics on EVAL set")
    tf.compat.v1.logging.info("Batch size = %d", FLAGS.eval_batch_size)
    tf.compat.v1.logging.info("Sequence Length = %d", FLAGS.max_seq_length)
    tf.compat.v1.logging.info("Precision = %s", "fp16" if FLAGS.amp else "fp32")
    tf.compat.v1.logging.info("Inference Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
    dllogging.logger.log(step=(), data={"throughput_val": ss_sentences_per_second}, verbosity=Verbosity.DEFAULT)
    tf.compat.v1.logging.info("-----------------------------")

    output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
    with tf.io.gfile.GFile(output_eval_file, "w") as writer:
      tf.compat.v1.logging.info("***** Eval results *****")
      for key in sorted(result.keys()):
        tf.compat.v1.logging.info("  %s = %s", key, str(result[key]))
        writer.write("%s = %s\n" % (key, str(result[key])))
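The `not FLAGS.herring or hvd.rank() == 0` guard recurs throughout this script; a small helper (hypothetical, not in the original) would make the intent explicit:

def is_rank0():
  # Single-process runs count as rank 0; call only after hvd.init().
  return not FLAGS.herring or hvd.rank() == 0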
Example #8
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  if FLAGS.amp:
      os.environ["TF_ENABLE_AUTO_MIXED_PRECISION"] = "1"
  else:
      os.environ["TF_ENABLE_AUTO_MIXED_PRECISION"] = "0"

  # Set seed to reduce randomness
  np.random.seed(FLAGS.seed)
  tf.set_random_seed(FLAGS.seed)

  hvd.init()

  flags.mark_flag_as_required('model_dir')
  flags.mark_flag_as_required('pipeline_config_path')
  session_config = tf.ConfigProto()
  session_config.gpu_options.per_process_gpu_memory_fraction = 0.9
  session_config.gpu_options.visible_device_list = str(hvd.local_rank())
  if FLAGS.allow_xla:
      session_config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
  model_dir = FLAGS.model_dir if hvd.rank() == 0 else None
  config = tf.estimator.RunConfig(tf_random_seed=(FLAGS.seed + hvd.rank()),
                                  model_dir=model_dir, session_config=session_config)

  train_and_eval_dict = model_lib.create_estimator_and_inputs(
      run_config=config,
      eval_count=FLAGS.eval_count,
      hparams=model_hparams.create_hparams(FLAGS.hparams_overrides),
      pipeline_config_path=FLAGS.pipeline_config_path,
      train_steps=FLAGS.num_train_steps,
      sample_1_of_n_eval_examples=FLAGS.sample_1_of_n_eval_examples,
      sample_1_of_n_eval_on_train_examples=(
          FLAGS.sample_1_of_n_eval_on_train_examples))
  estimator = train_and_eval_dict['estimator']
  train_input_fn = train_and_eval_dict['train_input_fn']
  eval_input_fns = train_and_eval_dict['eval_input_fns']
  eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
  predict_input_fn = train_and_eval_dict['predict_input_fn']
  train_steps = train_and_eval_dict['train_steps']

  if FLAGS.checkpoint_dir:
    if FLAGS.eval_training_data:
      name = 'training_data'
      input_fn = eval_on_train_input_fn
    else:
      name = 'validation_data'
      # The first eval input will be evaluated.
      input_fn = eval_input_fns[0]
    if FLAGS.run_once:
      estimator.evaluate(input_fn,
                         steps=None,
                         checkpoint_path=tf.train.latest_checkpoint(
                             FLAGS.checkpoint_dir))
    else:
      model_lib.continuous_eval(estimator, FLAGS.checkpoint_dir, input_fn,
                                train_steps, name)
  else:
    train_spec, eval_specs = model_lib.create_train_and_eval_specs(
        train_input_fn,
        eval_input_fns,
        eval_on_train_input_fn,
        predict_input_fn,
        train_steps,
        eval_on_train_data=False)

    train_hooks = [
        hvd.BroadcastGlobalVariablesHook(0),
        DLLoggerHook(hvd.size() * train_and_eval_dict['train_batch_size'],
                     hvd.rank())
    ]
    eval_hooks = []

    for _ in range(FLAGS.eval_count):
        estimator.train(train_input_fn,
                        hooks=train_hooks,
                        steps=train_steps // FLAGS.eval_count)

        if hvd.rank() == 0 and not FLAGS.train_only:
            eval_input_fn = eval_input_fns[0]
            results = estimator.evaluate(eval_input_fn,
                               steps=None,
                               hooks=eval_hooks)
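Scripts in this style are launched with one process per GPU; with Horovod's launcher, for example (the script name, process count, and paths are placeholders):

horovodrun -np 8 python model_main.py --model_dir=/tmp/model --pipeline_config_path=pipeline.config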
Example #9

import tensorflow as tf
import argparse
import os

########################################################
####### 1. SageMaker Distributed Data Parallel  ########
#######  - Import Package and Initialization    ########
########################################################

# Import SMDataParallel TensorFlow2 Modules
import smdistributed.dataparallel.tensorflow as smdp

# SMDataParallel: Initialize
smdp.init()

#######################################################


def train(args):
    tf.random.set_seed(args.seed)

    gpus = tf.config.experimental.list_physical_devices('GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    if gpus:
        # Pin GPUs to a single SMDataParallel process [use SMDataParallel local_rank() API]
        tf.config.experimental.set_visible_devices(gpus[args.local_rank],
                                                   'GPU')
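The fragment above reads the GPU index from args.local_rank; a hypothetical wiring (names are assumptions) would default it to SMDataParallel's local rank:

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42)
    # SMDataParallel: one process per GPU, indexed by local_rank()
    parser.add_argument('--local_rank', type=int, default=smdp.local_rank())
    train(parser.parse_args())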
Example #10
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and limitations under the License.

# Third Party
import smdistributed.dataparallel.tensorflow as smdataparallel
import tensorflow as tf

# SMDataParallel: initialize the library (this also registers its shutdown hook)
smdataparallel.init()

gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    tf.config.experimental.set_visible_devices(gpus[smdataparallel.local_rank()], "GPU")

(mnist_images, mnist_labels), _ = tf.keras.datasets.mnist.load_data(
    path="mnist-%d.npz" % smdataparallel.rank()
)

dataset = tf.data.Dataset.from_tensor_slices(
    (tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32), tf.cast(mnist_labels, tf.int64))
)
dataset = dataset.repeat().shuffle(10000).batch(128)

mnist_model = tf.keras.Sequential(