def main(_):
    """Build a ResNet inference graph, restore or initialize it, and freeze it.

    The graph takes a float32 placeholder named ``input`` laid out per
    ``--input_format`` (NCHW or NHWC). Its variables are folded into constants
    with ``convert_variables_to_constants`` using the probability op as the
    output node. The result is written to ``--output_file``, as text when
    ``--write_text_graphdef`` is set.

    Raises:
        ValueError: if ``--output_file`` is empty.
    """
    # Validate flags before doing any distributed/runtime setup.
    if not FLAGS.output_file:
        raise ValueError(
            'You must supply the path to save to with --output_file')

    # Initialize Horovod (TODO: Remove dependency of horovod for freezing graphs)
    hvd.init()

    tf.logging.set_verbosity(tf.logging.INFO)

    with tf.Graph().as_default() as graph:
        if FLAGS.input_format == 'NCHW':
            input_shape = [
                FLAGS.batch_size, 3, FLAGS.image_size, FLAGS.image_size
            ]
        else:
            input_shape = [
                FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, 3
            ]
        input_images = tf.placeholder(name='input', dtype=tf.float32,
                                      shape=input_shape)

        # Architecture hyper-parameters for whichever --model_name was chosen.
        model_config = resnet.model_architectures[FLAGS.model_name]
        network = resnet.ResnetModel(FLAGS.model_name,
                                     FLAGS.num_classes,
                                     model_config['layers'],
                                     model_config['widths'],
                                     model_config['expansions'],
                                     FLAGS.compute_format,
                                     FLAGS.input_format)
        # Only the probabilities op is exported; the logits are not needed.
        probs, _ = network.build_model(
            input_images,
            training=False,
            reuse=False,
            use_final_conv=FLAGS.use_final_conv)

        if FLAGS.quantize:
            tf.contrib.quantize.experimental_create_eval_graph(
                symmetric=FLAGS.symmetric, use_qdq=FLAGS.use_qdq)

        # Define the saver and restore the checkpoint
        saver = tf.train.Saver()
        with tf.Session() as sess:
            if FLAGS.checkpoint:
                saver.restore(sess, FLAGS.checkpoint)
            else:
                # No checkpoint: freeze freshly initialized weights.
                sess.run(tf.global_variables_initializer())
            graph_def = graph.as_graph_def()
            frozen_graph_def = tf.graph_util.convert_variables_to_constants(
                sess, graph_def, [probs.op.name])

    # Write out the frozen graph. A bare filename has an empty dirname, so
    # fall back to the current directory rather than passing '' as logdir.
    output_dir = os.path.dirname(FLAGS.output_file) or '.'
    tf.io.write_graph(frozen_graph_def,
                      output_dir,
                      os.path.basename(FLAGS.output_file),
                      as_text=FLAGS.write_text_graphdef)
from __future__ import absolute_import from __future__ import division from __future__ import print_function import abc import os import six import math import multiprocessing import tensorflow as tf import smdistributed.dataparallel.tensorflow as hvd hvd.init() from mask_rcnn.utils.logging_formatter import logging from mask_rcnn.utils.distributed_utils import MPI_is_distributed from mask_rcnn.utils.distributed_utils import MPI_local_rank from mask_rcnn.utils.distributed_utils import MPI_rank from mask_rcnn.hooks.logging_hook import AutoLoggingHook from mask_rcnn.utils.lazy_imports import LazyImport from tensorflow.core.protobuf import rewriter_config_pb2 from mask_rcnn import evaluation from mask_rcnn.hyperparameters import params_io from mask_rcnn.hooks import CheckpointSaverHook
# SageMaker-provided locations. NOTE(review): os.environ[...] raises KeyError
# when these variables are unset (e.g. running outside SageMaker).
parser.add_argument("--model_dir", type=str, default=os.environ["SM_MODEL_DIR"])
# NOTE(review): n_gpus is parsed as a string; convert before using it as a count.
parser.add_argument("--n_gpus", type=str, default=os.environ["SM_NUM_GPUS"])

# Unrecognized command-line arguments are discarded into `_`.
args, _ = parser.parse_known_args()

# Set up logging
logger = logging.getLogger(__name__)

logging.basicConfig(
    level=logging.getLevelName("INFO"),
    handlers=[logging.StreamHandler(sys.stdout)],
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)

if SDP_ENABLED:
    sdp.init()

    # GPU pinning stays under the SDP guard: sdp.local_rank() is only
    # meaningful after sdp.init(). Each process sees just its local-rank GPU.
    gpus = tf.config.experimental.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        tf.config.experimental.set_visible_devices(gpus[sdp.local_rank()], "GPU")

# Load model and tokenizer
model = TFAutoModelForSequenceClassification.from_pretrained(args.model_name)
tokenizer = AutoTokenizer.from_pretrained(args.model_name)

# get datasets
tf_train_dataset, tf_test_dataset = get_datasets()

# define optimizer and loss
def main(args):
    """Train, validate and test a Keras model with SMDataParallel.

    Each process is pinned to the GPU matching its SMDataParallel local rank.
    Gradients go through ``smdp.DistributedGradientTape``, and the learning
    rate is multiplied by the number of workers. Rank 0 alone writes
    TensorBoard summaries, prints the per-epoch report and saves the model.

    Args:
        args: parsed command-line namespace. Reads ``epochs``,
            ``learning_rate``, ``batch_size``, ``momentum``, ``weight_decay``,
            ``optimizer`` ('adam', 'rmsprop', anything else -> SGD),
            ``model_type``, ``train``, ``validation`` and ``eval``.
    """
    # Hyper-parameters
    epochs = args.epochs
    lr = args.learning_rate
    batch_size = args.batch_size
    momentum = args.momentum
    weight_decay = args.weight_decay
    optimizer = args.optimizer
    model_type = args.model_type

    # SageMaker options
    training_dir = args.train
    validation_dir = args.validation
    eval_dir = args.eval

    # Change: Initialize SMDataParallel and get the size of the cluster
    smdp.init()
    size = smdp.size()

    # Change: Pin GPU to local process (one GPU per process)
    gpus = tf.config.experimental.list_physical_devices('GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        # SMDataParallel: Pin GPUs to a single SMDataParallel process [use SMDataParallel local_rank() API]
        tf.config.experimental.set_visible_devices(gpus[smdp.local_rank()], 'GPU')

    # Get dataset. Each worker keeps NUM_TRAIN_IMAGES // size elements of the
    # training set.
    train_dataset = get_dataset(training_dir + '/train.tfrecords', batch_size)
    train_dataset = train_dataset.take(NUM_TRAIN_IMAGES // size).shuffle(10000)
    val_dataset = get_dataset(validation_dir + '/validation.tfrecords', batch_size)
    eval_dataset = get_dataset(eval_dir + '/eval.tfrecords', batch_size)

    # Load model
    model = get_model(model_type)

    # Optimizer (learning rate scaled by the number of workers)
    optimizer_name = optimizer.lower()
    if optimizer_name == 'adam':
        opt = Adam(lr=lr * size, decay=weight_decay)
    elif optimizer_name == 'rmsprop':
        opt = RMSprop(lr=lr * size, decay=weight_decay)
    else:
        opt = SGD(lr=lr * size, decay=weight_decay, momentum=momentum)

    # Loss function
    loss = tf.keras.losses.CategoricalCrossentropy()

    # Metrics to track
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')

    val_loss = tf.keras.metrics.Mean(name='val_loss')
    val_accuracy = tf.keras.metrics.CategoricalAccuracy(name='val_accuracy')

    test_loss = tf.keras.metrics.Mean(name='test_loss')
    test_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy')

    # Training step
    @tf.function
    def training_step(images, labels, first_batch):
        with tf.GradientTape() as tape:
            train_pred = model(images, training=True)
            loss_value = loss(labels, train_pred)
        # Change: Wrap tf.GradientTape with SMDataParallel's DistributedGradientTape
        tape = smdp.DistributedGradientTape(tape)

        grads = tape.gradient(loss_value, model.trainable_variables)
        opt.apply_gradients(zip(grads, model.trainable_variables))

        if first_batch:
            # Change: Broadcast model and optimizer variables from rank 0.
            # Done after the first apply_gradients, presumably so that the
            # optimizer's variables already exist when they are broadcast.
            smdp.broadcast_variables(model.variables, root_rank=0)
            smdp.broadcast_variables(opt.variables(), root_rank=0)

        # Change: all_reduce call
        train_loss_value = smdp.oob_allreduce(loss_value)  # Average the loss across workers

        train_loss(train_loss_value)
        train_accuracy(labels, train_pred)
        return

    # Validation step (accumulates into the val_* metrics)
    @tf.function
    def test_step(images, labels):
        val_pred = model(images, training=False)
        val_loss_value = loss(labels, val_pred)

        val_loss(val_loss_value)
        val_accuracy(labels, val_pred)
        return

    if smdp.rank() == 0:
        tb_log_dir = '/opt/ml/output/tensorboard/'
        train_summary_writer = tf.summary.create_file_writer(tb_log_dir)
        test_summary_writer = tf.summary.create_file_writer(tb_log_dir)

    # Training loop
    for epoch in range(epochs):
        train_loss.reset_states()
        train_accuracy.reset_states()

        val_loss.reset_states()
        val_accuracy.reset_states()

        # Time the whole training pass. Taking start_time per batch would
        # report only the last batch's duration.
        start_time = time.time()
        for batch, (images, labels) in enumerate(train_dataset):
            training_step(images, labels, batch == 0)
        epoch_time = time.time() - start_time

        for images, labels in val_dataset:
            test_step(images, labels)

        if smdp.rank() == 0:
            with train_summary_writer.as_default():
                tf.summary.scalar('train_loss', train_loss.result(), step=epoch)
                tf.summary.scalar('train_accuracy', train_accuracy.result(), step=epoch)

            with test_summary_writer.as_default():
                tf.summary.scalar('val_loss', val_loss.result(), step=epoch)
                tf.summary.scalar('val_accuracy', val_accuracy.result(), step=epoch)
            # One concatenated f-string, so the fields are joined by ", ".
            print(
                f'Epoch: {epoch + 1}, '
                f'Epoch duration: {epoch_time} sec, '
                f'Training loss: {train_loss.result()}, '
                f'Training accuracy: {train_accuracy.result() * 100}, '
                f'Validation Loss: {val_loss.result()}, '
                f'Validation Accuracy: {val_accuracy.result() * 100}')

    # Final evaluation on the held-out eval set (runs on every rank).
    for images, labels in eval_dataset:
        test_pred = model(images, training=False)
        test_loss_value = loss(labels, test_pred)

        test_loss(test_loss_value)
        test_accuracy(labels, test_pred)

    print('====== Test Results ======')
    print(f'Test loss: {test_loss.result()}, '
          f'Test accuracy: {test_accuracy.result() * 100}')
    print('====== End of training ======')

    # Change: Save checkpoints only from master node.
    if smdp.rank() == 0:
        model.save(os.path.join(os.environ["SM_MODEL_DIR"], '1'))
# Copyright 2018 Uber Technologies, Inc. All Rights Reserved. # Modifications Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. A copy of the License is located at # # http://aws.amazon.com/apache2.0/ # # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and limitations under the License. import tensorflow as tf tf.random.set_seed(42) import smdistributed.dataparallel.tensorflow as dist dist.init() gpus = tf.config.experimental.list_physical_devices('GPU') for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True) if gpus: tf.config.experimental.set_visible_devices(gpus[dist.local_rank()], 'GPU') (mnist_images, mnist_labels), _ = \ tf.keras.datasets.mnist.load_data(path='mnist-%d.npz' % dist.rank()) dataset = tf.data.Dataset.from_tensor_slices( (tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32), tf.cast(mnist_labels, tf.int64))) dataset = dataset.repeat().shuffle(10000).batch(128) mnist_model = tf.keras.Sequential([ tf.keras.layers.Conv2D(32, [3, 3], activation='relu'),
from common.datasets import get_dataset_from_tfrecords from common.models import create_model from common.optimizers import get_adamw_optimizer, get_lamb_optimizer from common.utils import ( TqdmLoggingHandler, create_tokenizer, gather_indexes, is_wandb_available, rewrap_tf_function, ) # See https://github.com/huggingface/transformers/issues/3782; this import must come last import smdistributed.dataparallel.tensorflow as smddp # isort:skip import smddpcommon as hc smddp.init() if is_wandb_available(): import wandb # Exclude Pack operator in XLA for 2.4 if not _PRE_TF_2_4_0: os.environ['TF_XLA_FLAGS'] = "--tf_xla_auto_jit=1 --tf_xla_ops_to_cluster=Add,AddN,AddV2,All,ArgMax,AssignAddVariableOp,AssignVariableOp,BatchMatMulV2,BiasAdd,BiasAddGrad,Cast,ConcatV2,Const,Equal,Erf,Exp,ExpandDims,GatherV2,Greater,GreaterEqual,Identity,IdentityN,If,IsFinite,L2Loss,LessEqual,MatMul,Maximum,Mean,Minimum,Mul,Neg,NoOp,PartitionedCall,Pow,RandomUniform,ReadVariableOp,RealDiv,Reciprocal,Reshape,ResourceGather,Rsqrt,RsqrtGrad,SelectV2,Softmax,SparseSoftmaxCrossEntropyWithLogits,Sqrt,Square,SquaredDifference,Squeeze,StatelessIf,StridedSlice,StridedSliceGrad,Sub,Sum,Tanh,TanhGrad,Tile,Transpose,UnsortedSegmentSum,VariableShape" tf.keras.mixed_precision.set_global_policy('mixed_float16') logger = logging.getLogger(__name__) def mlm_loss_fn( prediction_logits: "[batch, max_seq_len (512), vocab_size]", label_positions: "[batch, num_masks (20)]",
def _sentences_per_second(num_sentences, seconds):
    """Return ``num_sentences / seconds``, or 0.0 when no time was measured.

    Keeps the throughput reports from raising ZeroDivisionError when a hook
    recorded no (or zero-length) steps, e.g. a very short eval run.
    """
    if seconds <= 0:
        return 0.0
    return num_sentences * 1.0 / seconds


def main(_):
    """Run BERT pre-training and/or evaluation with tf.estimator.

    With ``--herring`` set, SMDataParallel (imported as ``hvd``) is initialized.
    Each worker is then pinned to its local GPU, the learning rate is multiplied
    by the worker count, and only rank 0 saves checkpoints/summaries, logs
    throughput and runs evaluation.

    Raises:
        ValueError: if neither ``do_train`` nor ``do_eval`` is set, if there are
            fewer input files than SMDataParallel workers, or if AMP and manual
            fp16 are both enabled.
    """
    # Lazy compilation causes memory fragmentation for BERT, leading to OOM.
    os.environ["TF_XLA_FLAGS"] = "--tf_xla_enable_lazy_compilation=false"
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    dllogging = utils.dllogger_class.dllogger_class(FLAGS.dllog_path)

    if not FLAGS.do_train and not FLAGS.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    # Set seed to reduce randomness
    random.seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    tf.set_random_seed(FLAGS.seed)

    if FLAGS.herring:
        import smdistributed.dataparallel.tensorflow as hvd
        hvd.init()

    # True on the single process that checkpoints, logs and evaluates.
    # Short-circuits so hvd is never touched when --herring is off.
    is_chief = not FLAGS.herring or hvd.rank() == 0

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    tf.io.gfile.makedirs(FLAGS.output_dir)

    input_files = []
    for input_file_dir in FLAGS.input_files_dir.split(","):
        input_files.extend(tf.io.gfile.glob(os.path.join(input_file_dir, "*")))

    if FLAGS.herring and len(input_files) < hvd.size():
        raise ValueError("Input Files must be sharded")
    if FLAGS.amp and FLAGS.manual_fp16:
        raise ValueError("AMP and Manual Mixed Precision Training are both activated! Error")

    config = tf.compat.v1.ConfigProto()
    if FLAGS.herring:
        # One GPU per process.
        config.gpu_options.visible_device_list = str(hvd.local_rank())
        if hvd.rank() == 0:
            tf.compat.v1.logging.info("***** Configuration *****")
            for key in FLAGS.__flags.keys():
                tf.compat.v1.logging.info(' {}: {}'.format(key, getattr(FLAGS, key)))
            tf.compat.v1.logging.info("**************************")

    if FLAGS.use_xla:
        config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
        config.graph_options.rewrite_options.memory_optimization = rewriter_config_pb2.RewriterConfig.NO_MEM_OPT
        if FLAGS.amp:
            tf.enable_resource_variables()

    run_config = tf.estimator.RunConfig(
        # Distinct seed per worker under SMDataParallel.
        tf_random_seed=(FLAGS.seed if not FLAGS.herring else (FLAGS.seed + hvd.rank())),
        model_dir=FLAGS.output_dir,
        session_config=config,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps if is_chief else None,
        save_summary_steps=FLAGS.save_checkpoints_steps if is_chief else None,
        # This variable controls how often estimator reports examples/sec.
        # Default value is every 100 steps.
        # When --report_loss is True, we set to very large value to prevent
        # default info reporting from estimator.
        # Ideally we should set it to None, but that does not work.
        log_step_count_steps=10000 if FLAGS.report_loss else 100)

    model_fn = model_fn_builder(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate if not FLAGS.herring else FLAGS.learning_rate * hvd.size(),
        num_train_steps=FLAGS.num_train_steps,
        num_warmup_steps=FLAGS.num_warmup_steps,
        use_one_hot_embeddings=False,
        hvd=None if not FLAGS.herring else hvd)

    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        config=run_config)

    if FLAGS.do_train:

        training_hooks = []
        if FLAGS.herring and hvd.size() > 1:
            training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))
        if is_chief:
            global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps
            if FLAGS.herring:
                global_batch_size *= hvd.size()
            training_hooks.append(_LogSessionRunHook(global_batch_size, FLAGS.num_accumulation_steps, dllogging,
                                                     FLAGS.display_loss_steps, FLAGS.save_checkpoints_steps,
                                                     FLAGS.report_loss))

        tf.compat.v1.logging.info("***** Running training *****")
        tf.compat.v1.logging.info(" Batch size = %d", FLAGS.train_batch_size)
        train_input_fn = input_fn_builder(
            input_files=input_files,
            batch_size=FLAGS.train_batch_size,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            is_training=True,
            hvd=None if not FLAGS.herring else hvd)

        train_start_time = time.time()
        estimator.train(input_fn=train_input_fn, hooks=training_hooks, max_steps=FLAGS.num_train_steps)
        train_time_elapsed = time.time() - train_start_time

        if is_chief:
            # On the chief the logging hook was appended last.
            log_hook = training_hooks[-1]
            train_time_wo_overhead = log_hook.total_time
            num_sentences_total = FLAGS.num_train_steps * global_batch_size
            num_sentences_wo_overhead = (FLAGS.num_train_steps - log_hook.skipped) * global_batch_size
            avg_sentences_per_second = _sentences_per_second(num_sentences_total, train_time_elapsed)
            ss_sentences_per_second = _sentences_per_second(num_sentences_wo_overhead, train_time_wo_overhead)

            tf.compat.v1.logging.info("-----------------------------")
            tf.compat.v1.logging.info("Total Training Time = %0.2f for Sentences = %d", train_time_elapsed,
                                      num_sentences_total)
            tf.compat.v1.logging.info("Total Training Time W/O Overhead = %0.2f for Sentences = %d",
                                      train_time_wo_overhead, num_sentences_wo_overhead)
            tf.compat.v1.logging.info("Training Throughput Average (sentences/sec) with overhead = %0.2f",
                                      avg_sentences_per_second)
            tf.compat.v1.logging.info("Training Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
            dllogging.logger.log(step=(), data={"throughput_train": ss_sentences_per_second},
                                 verbosity=Verbosity.DEFAULT)
            tf.compat.v1.logging.info("-----------------------------")

    if FLAGS.do_eval and is_chief:
        tf.compat.v1.logging.info("***** Running evaluation *****")
        tf.compat.v1.logging.info(" Batch size = %d", FLAGS.eval_batch_size)

        eval_files = []
        for eval_file_dir in FLAGS.eval_files_dir.split(","):
            eval_files.extend(tf.io.gfile.glob(os.path.join(eval_file_dir, "*")))

        eval_input_fn = input_fn_builder(
            input_files=eval_files,
            batch_size=FLAGS.eval_batch_size,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            is_training=False,
            hvd=None if not FLAGS.herring else hvd)

        eval_hooks = [LogEvalRunHook(FLAGS.eval_batch_size)]
        eval_start_time = time.time()
        result = estimator.evaluate(
            input_fn=eval_input_fn, steps=FLAGS.max_eval_steps, hooks=eval_hooks)

        eval_time_elapsed = time.time() - eval_start_time

        time_list = sorted(eval_hooks[-1].time_list)
        # Removing outliers (init/warmup) in throughput computation: keep the
        # fastest 99% of step times. Keep at least one step whenever any were
        # recorded, so a single-step eval does not divide by zero.
        num_kept = max(1, int(len(time_list) * 0.99)) if time_list else 0
        eval_time_wo_overhead = sum(time_list[:num_kept])
        num_sentences = num_kept * FLAGS.eval_batch_size

        ss_sentences_per_second = _sentences_per_second(num_sentences, eval_time_wo_overhead)

        tf.compat.v1.logging.info("-----------------------------")
        tf.compat.v1.logging.info("Total Inference Time = %0.2f for Sentences = %d", eval_time_elapsed,
                                  eval_hooks[-1].count * FLAGS.eval_batch_size)
        tf.compat.v1.logging.info("Total Inference Time W/O Overhead = %0.2f for Sentences = %d",
                                  eval_time_wo_overhead, num_sentences)
        tf.compat.v1.logging.info("Summary Inference Statistics on EVAL set")
        tf.compat.v1.logging.info("Batch size = %d", FLAGS.eval_batch_size)
        tf.compat.v1.logging.info("Sequence Length = %d", FLAGS.max_seq_length)
        tf.compat.v1.logging.info("Precision = %s", "fp16" if FLAGS.amp else "fp32")
        tf.compat.v1.logging.info("Inference Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
        dllogging.logger.log(step=(), data={"throughput_val": ss_sentences_per_second},
                             verbosity=Verbosity.DEFAULT)
        tf.compat.v1.logging.info("-----------------------------")

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.io.gfile.GFile(output_eval_file, "w") as writer:
            tf.compat.v1.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.compat.v1.logging.info(" %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
def _split_train_steps(total_steps, rounds):
    """Split ``total_steps`` into ``rounds`` near-equal chunks summing to it.

    Any remainder is spread over the later rounds instead of being dropped,
    e.g. (1000, 3) -> [333, 333, 334]. A non-positive ``rounds`` yields [].
    """
    return [
        total_steps * (i + 1) // rounds - total_steps * i // rounds
        for i in range(rounds)
    ]


def main(unused_argv):
    """Train and/or evaluate an object-detection model with Horovod + Estimator.

    With ``--checkpoint_dir`` set, the function only evaluates: once on the
    latest checkpoint (``--run_once``) or continuously via
    ``model_lib.continuous_eval``. Otherwise it trains for ``train_steps``
    steps split into ``--eval_count`` rounds. After each round rank 0
    evaluates on the first eval input, unless ``--train_only`` is set.

    Args:
        unused_argv: positional command-line arguments (ignored).
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    os.environ["TF_ENABLE_AUTO_MIXED_PRECISION"] = "1" if FLAGS.amp else "0"

    # Set seed to reduce randomness
    np.random.seed(FLAGS.seed)
    tf.set_random_seed(FLAGS.seed)

    hvd.init()

    flags.mark_flag_as_required('model_dir')
    flags.mark_flag_as_required('pipeline_config_path')

    session_config = tf.ConfigProto()
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.9
    # One GPU per Horovod process.
    session_config.gpu_options.visible_device_list = str(hvd.local_rank())
    if FLAGS.allow_xla:
        session_config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

    # Only rank 0 is given model_dir; other ranks pass model_dir=None.
    model_dir = FLAGS.model_dir if hvd.rank() == 0 else None
    config = tf.estimator.RunConfig(tf_random_seed=(FLAGS.seed + hvd.rank()),
                                    model_dir=model_dir,
                                    session_config=session_config)

    train_and_eval_dict = model_lib.create_estimator_and_inputs(
        run_config=config,
        eval_count=FLAGS.eval_count,
        hparams=model_hparams.create_hparams(FLAGS.hparams_overrides),
        pipeline_config_path=FLAGS.pipeline_config_path,
        train_steps=FLAGS.num_train_steps,
        sample_1_of_n_eval_examples=FLAGS.sample_1_of_n_eval_examples,
        sample_1_of_n_eval_on_train_examples=(
            FLAGS.sample_1_of_n_eval_on_train_examples))
    estimator = train_and_eval_dict['estimator']
    train_input_fn = train_and_eval_dict['train_input_fn']
    eval_input_fns = train_and_eval_dict['eval_input_fns']
    eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
    predict_input_fn = train_and_eval_dict['predict_input_fn']
    train_steps = train_and_eval_dict['train_steps']

    if FLAGS.checkpoint_dir:
        if FLAGS.eval_training_data:
            name = 'training_data'
            input_fn = eval_on_train_input_fn
        else:
            name = 'validation_data'
            # The first eval input will be evaluated.
            input_fn = eval_input_fns[0]
        if FLAGS.run_once:
            estimator.evaluate(input_fn,
                               steps=None,
                               checkpoint_path=tf.train.latest_checkpoint(
                                   FLAGS.checkpoint_dir))
        else:
            model_lib.continuous_eval(estimator, FLAGS.checkpoint_dir, input_fn,
                                      train_steps, name)
    else:
        # NOTE(review): the returned specs are not used below; the call is
        # kept unchanged in case it validates the inputs.
        train_spec, eval_specs = model_lib.create_train_and_eval_specs(
            train_input_fn,
            eval_input_fns,
            eval_on_train_input_fn,
            predict_input_fn,
            train_steps,
            eval_on_train_data=False)

        train_hooks = [hvd.BroadcastGlobalVariablesHook(0),
                       DLLoggerHook(hvd.size() * train_and_eval_dict['train_batch_size'], hvd.rank())]
        eval_hooks = []

        # Plain `train_steps // eval_count` per round would silently drop the
        # remainder, so the rounds share the full train_steps instead.
        for round_steps in _split_train_steps(train_steps, FLAGS.eval_count):
            # Estimator.train rejects steps <= 0. An empty round only happens
            # when train_steps < eval_count.
            if round_steps > 0:
                estimator.train(train_input_fn,
                                hooks=train_hooks,
                                steps=round_steps)

            if hvd.rank() == 0 and not FLAGS.train_only:
                eval_input_fn = eval_input_fns[0]
                estimator.evaluate(eval_input_fn,
                                   steps=None,
                                   hooks=eval_hooks)
# under the License. import tensorflow as tf import argparse import os ######################################################## ####### 1. SageMaker Distributed Data Parallel ######## ####### - Import Package and Initialization ######## ######################################################## # Import SMDataParallel TensorFlow2 Modules import smdistributed.dataparallel.tensorflow as smdp # SMDataParallel: Initialize smdp.init() ####################################################### def train(args): tf.random.set_seed(args.seed) gpus = tf.config.experimental.list_physical_devices('GPU') for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True) if gpus: # Pin GPUs to a single SMDataParallel process [use SMDataParallel local_rank() API] tf.config.experimental.set_visible_devices(gpus[args.local_rank], 'GPU')
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. A copy of the License is located at # # http://aws.amazon.com/apache2.0/ # # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and limitations under the License. # Third Party import smdistributed.dataparallel.tensorflow as smdataparallel import tensorflow as tf # Register smdataparallel shutdown hook smdataparallel.init() gpus = tf.config.experimental.list_physical_devices("GPU") for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True) if gpus: tf.config.experimental.set_visible_devices(gpus[smdataparallel.local_rank()], "GPU") (mnist_images, mnist_labels), _ = tf.keras.datasets.mnist.load_data( path="mnist-%d.npz" % smdataparallel.rank() ) dataset = tf.data.Dataset.from_tensor_slices( (tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32), tf.cast(mnist_labels, tf.int64)) ) dataset = dataset.repeat().shuffle(10000).batch(128) mnist_model = tf.keras.Sequential(