def main(args):
    samples, labels = restore_from_file(args.filename)
    dataset = tf_dataset(samples, labels,
                         batchsize=args.global_batch_size,
                         to_sparse_tensor=False,
                         repeat=1)
    dataset = dataset.apply(tf.data.experimental.prefetch_to_device(
        device='/GPU:0', buffer_size=tf.data.AUTOTUNE))

    model = DemoModel(max_vocabulary_size_per_gpu=args.vocabulary_size,
                      embedding_vec_size=args.embedding_vec_size,
                      slot_num=args.slot_num,
                      nnz_per_slot=args.nnz_per_slot,
                      num_dense_layers=args.num_dense_layers,
                      use_sok=False)

    optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)

    loss_fn = tf.keras.losses.BinaryCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    def _replica_loss(labels, logits):
        loss = loss_fn(labels, logits)
        return tf.nn.compute_average_loss(
            loss, global_batch_size=args.global_batch_size)

    @tf.function
    def _train_step(inputs, labels):
        with tf.GradientTape() as tape:
            logit = model(inputs)
            loss = _replica_loss(labels, logit)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss

    for step, (inputs, labels) in enumerate(dataset):
        if (-1 != args.early_stop_iter) and (step >= args.early_stop_iter):
            break

        rng = nvtx.start_range(message="Iteration_" + str(step), color='blue')
        loss = _train_step(inputs, labels)
        tf.print("[INFO]: Iter: %d, Loss: %.5f" % (step, loss))
        nvtx.end_range(rng)

    print("[INFO]: Profiling TF on single GPU done.")
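# The loop above brackets every training step with paired nvtx.start_range /
# nvtx.end_range calls, so each step shows up as a named range on the Nsight
# Systems timeline. When the region of interest is a single lexical block, the
# same nvtx package also offers a context-manager/decorator form. A minimal
# sketch; the message and function names are illustrative, not from the script:
import nvtx

# Context-manager form: everything inside the block becomes one named range.
with nvtx.annotate(message="Iteration_0", color="blue"):
    loss = _train_step(inputs, labels)  # assumes the _train_step defined above

# Decorator form: every call to the function is recorded as a range.
@nvtx.annotate(message="train_step", color="green")
def profiled_step(inputs, labels):
    return _train_step(inputs, labels)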
def main(args):
    strategy = tf.distribute.MirroredStrategy()

    dataset = utility.TFDataset(filename=args.data_filename,
                                batchsize=args.global_batch_size,
                                as_sparse_tensor=False,
                                repeat=1)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    dataset = strategy.experimental_distribute_dataset(dataset)

    with strategy.scope():
        sok.Init(global_batch_size=args.global_batch_size)

        model = SOKDenseDemo(
            max_vocabulary_size_per_gpu=args.max_vocabulary_size_per_gpu,
            embedding_vec_size=args.embedding_vec_size,
            slot_num=args.slot_num,
            nnz_per_slot=args.nnz_per_slot,
            num_dense_layers=args.num_dense_layers)

        embedding_optimizer = utility.get_embedding_optimizer(
            args.optimizer)(learning_rate=0.1)
        dense_optimizer = utility.get_dense_optimizer(
            args.optimizer)(learning_rate=0.1)

    loss_fn = tf.keras.losses.BinaryCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    def _replica_loss(labels, logits):
        loss = loss_fn(labels, logits)
        return tf.nn.compute_average_loss(
            loss, global_batch_size=args.global_batch_size)

    @tf.function
    def _train_step(inputs, labels):
        with tf.GradientTape() as tape:
            logit = model(inputs, training=True)
            loss = _replica_loss(labels, logit)
        emb_variable, other_variable = sok.split_embedding_variable_from_others(
            model.trainable_variables)
        grads, emb_grads = tape.gradient(loss, [other_variable, emb_variable])

        if "plugin" not in args.optimizer:
            with sok.OptimizerScope(emb_variable):
                embedding_optimizer.apply_gradients(
                    zip(emb_grads, emb_variable),
                    experimental_aggregate_gradients=False)
        else:
            embedding_optimizer.apply_gradients(
                zip(emb_grads, emb_variable),
                experimental_aggregate_gradients=False)

        dense_optimizer.apply_gradients(zip(grads, other_variable))
        return loss

    for i, (inputs, labels) in enumerate(dataset):
        if args.stop_at_iter > 0 and i >= args.stop_at_iter:
            break

        rng = nvtx.start_range(message="Iteration_" + str(i), color="blue")
        replica_loss = strategy.run(_train_step, args=(inputs, labels))
        # Each replica's loss is already divided by the global batch size, so
        # summing across replicas yields the global-batch average loss.
        loss = strategy.reduce(tf.distribute.ReduceOp.SUM, replica_loss,
                               axis=None)
        nvtx.end_range(rng)

        print("[INFO]: Iteration: {}, loss={}".format(i, loss))
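# _replica_loss divides each replica's summed loss by the *global* batch size,
# so summing the per-replica results (ReduceOp.SUM above) reconstructs the
# plain mean over the whole global batch. A standalone NumPy illustration of
# that arithmetic (not part of the training script):
import numpy as np

per_example_losses = np.random.rand(8)         # pretend global batch of 8
shards = np.split(per_example_losses, 2)       # two replicas, 4 examples each
per_replica = [s.sum() / per_example_losses.size for s in shards]
assert np.isclose(sum(per_replica), per_example_losses.mean())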
async def _run(client, args):
    if args.type == "gpu":
        import cupy as xp
    else:
        import numpy as xp

    # Create a simple random array
    rs = da.random.RandomState(RandomState=xp.random.RandomState)

    if args.operation == "transpose_sum":
        rng = start_range(message="make array(s)", color="green")
        x = rs.random((args.size, args.size), chunks=args.chunk_size).persist()
        await wait(x)
        end_range(rng)
        func_args = (x,)
        func = lambda x: (x + x.T).sum()
    elif args.operation == "dot":
        rng = start_range(message="make array(s)", color="green")
        x = rs.random((args.size, args.size), chunks=args.chunk_size).persist()
        y = rs.random((args.size, args.size), chunks=args.chunk_size).persist()
        await wait(x)
        await wait(y)
        end_range(rng)
        func_args = (x, y)
        func = lambda x, y: x.dot(y)
    elif args.operation == "svd":
        rng = start_range(message="make array(s)", color="green")
        x = rs.random(
            (args.size, args.second_size),
            chunks=(int(args.chunk_size), args.second_size),
        ).persist()
        await wait(x)
        end_range(rng)
        func_args = (x,)
        func = lambda x: np.linalg.svd(x)
    elif args.operation == "fft":
        rng = start_range(message="make array(s)", color="green")
        x = rs.random((args.size, args.size),
                      chunks=(args.size, args.chunk_size)).persist()
        await wait(x)
        end_range(rng)
        func_args = (x,)
        func = lambda x: np.fft.fft(x, axis=0)
    elif args.operation == "sum":
        rng = start_range(message="make array(s)", color="green")
        x = rs.random((args.size, args.size), chunks=args.chunk_size).persist()
        await wait(x)
        end_range(rng)
        func_args = (x,)
        func = lambda x: x.sum()
    elif args.operation == "mean":
        rng = start_range(message="make array(s)", color="green")
        x = rs.random((args.size, args.size), chunks=args.chunk_size).persist()
        await wait(x)
        end_range(rng)
        func_args = (x,)
        func = lambda x: x.mean()
    elif args.operation == "slice":
        rng = start_range(message="make array(s)", color="green")
        x = rs.random((args.size, args.size), chunks=args.chunk_size).persist()
        await wait(x)
        end_range(rng)
        func_args = (x,)
        func = lambda x: x[::3].copy()
    elif args.operation == "col_sum":
        rng = start_range(message="make array(s)", color="green")
        x = rs.normal(10, 1, (args.size,), chunks=args.chunk_size).persist()
        y = rs.normal(10, 1, (args.size,), chunks=args.chunk_size).persist()
        await wait(x)
        await wait(y)
        end_range(rng)
        func_args = (x, y)
        func = lambda x, y: x + y
    elif args.operation == "col_mask":
        rng = start_range(message="make array(s)", color="green")
        x = rs.normal(10, 1, (args.size,), chunks=args.chunk_size).persist()
        y = rs.normal(10, 1, (args.size,), chunks=args.chunk_size).persist()
        await wait(x)
        await wait(y)
        end_range(rng)
        func_args = (x, y)
        func = lambda x, y: x[y > 10]
    elif args.operation == "col_gather":
        rng = start_range(message="make array(s)", color="green")
        x = rs.normal(10, 1, (args.size,), chunks=args.chunk_size).persist()
        idx = rs.randint(0, len(x), (args.second_size,),
                         chunks=args.chunk_size).persist()
        await wait(x)
        await wait(idx)
        end_range(rng)
        func_args = (x, idx)
        func = lambda x, idx: x[idx]

    shape = x.shape
    chunksize = x.chunksize

    # Execute the operations to benchmark
    if args.profile is not None:
        async with performance_report(filename=args.profile):
            rng = start_range(message=args.operation, color="purple")
            t1 = clock()
            await wait(client.persist(func(*func_args)))
            if args.type == "gpu":
                await client.run(lambda xp: xp.cuda.Device().synchronize(), xp)
            took = clock() - t1
            end_range(rng)
    else:
        rng = start_range(message=args.operation, color="purple")
        t1 = clock()
        await wait(client.persist(func(*func_args)))
        if args.type == "gpu":
            await client.run(lambda xp: xp.cuda.Device().synchronize(), xp)
        took = clock() - t1
        end_range(rng)

    return {
        "took": took,
        "npartitions": x.npartitions,
        "shape": shape,
        "chunksize": chunksize,
    }
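# _run is a coroutine that expects an asynchronous Dask client plus an argument
# namespace. A minimal driver sketch, assuming dask_cuda's LocalCUDACluster for
# the GPU path; the SimpleNamespace fields simply mirror the attributes _run
# reads and are illustrative, not the original benchmark's CLI:
import asyncio
from types import SimpleNamespace

from dask.distributed import Client
from dask_cuda import LocalCUDACluster

async def run_benchmark():
    args = SimpleNamespace(type="gpu", operation="transpose_sum",
                           size=10_000, second_size=1_000,
                           chunk_size=2_500, profile=None)
    async with LocalCUDACluster(asynchronous=True) as cluster:
        async with Client(cluster, asynchronous=True) as client:
            return await _run(client, args)

print(asyncio.run(run_benchmark()))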
print("-" * 10) print(f"epoch {epoch + 1}/{max_epochs}") model.train() epoch_loss = 0 train_loader_iterator = iter(train_loader) for step in range(len(train_loader)): step_start = time.time() rng_train_dataload = nvtx.start_range(message="dataload", color="red") batch_data = next(train_loader_iterator) inputs, labels = ( batch_data["image"], batch_data["label"], ) nvtx.end_range(rng_train_dataload) optimizer.zero_grad() rng_train_forward = nvtx.start_range(message="forward", color="green") with torch.cuda.amp.autocast(): outputs = model(inputs) loss = loss_function(outputs, labels) nvtx.end_range(rng_train_forward) rng_train_backward = nvtx.start_range(message="backward", color="blue") scaler.scale(loss).backward() nvtx.end_range(rng_train_backward) rng_train_update = nvtx.start_range(message="update", color="yellow") scaler.step(optimizer) scaler.update() nvtx.end_range(rng_train_update)
def main(args, task_id): print("task id={}".format(task_id)) comm_options = tf.distribute.experimental.CommunicationOptions( bytes_per_pack=0, timeout_seconds=None, implementation=tf.distribute.experimental.CommunicationImplementation. NCCL) # if MirroredStrategy is used here and _train_step is not decorated by @tf.function, # there will be a "Bad file descriptor" error related to multiprocessing at the end # of the program. #if args.total_gpu_num == 1: # strategy = tf.distribute.MirroredStrategy() if True: port = 12345 os.environ["TF_CONFIG"] = json.dumps({ "cluster": { "worker": [ "localhost" + ":" + str(port + i) for i in range(args.worker_num) ] }, "task": { "type": "worker", "index": task_id } }) strategy = tf.distribute.MultiWorkerMirroredStrategy( communication_options=comm_options) if args.data_splited: filename = args.data_filename + str(task_id) + ".file" else: filename = args.data_filename replica_batch_size = args.global_batch_size // (args.worker_num * 1) dataset = utility.TFDataset(filename=filename, batchsize=replica_batch_size, as_sparse_tensor=False, repeat=1) dataset = dataset.prefetch(tf.data.AUTOTUNE) with strategy.scope(): model = TfDenseDemo(global_batch_size=args.global_batch_size, vocabulary_size=args.vocabulary_size, slot_num=args.slot_num, nnz_per_slot=args.nnz_per_slot, num_dense_layers=args.num_dense_layers, embedding_vec_size=args.embedding_vec_size) emb_optimizer = utility.get_dense_optimizer( args.optimizer)(learning_rate=0.1) dense_optimizer = utility.get_dense_optimizer( args.optimizer)(learning_rate=0.1) loss_fn = tf.keras.losses.BinaryCrossentropy( from_logits=True, reduction=tf.keras.losses.Reduction.NONE) def _replica_loss(labels, logits): loss = loss_fn(labels, logits) return tf.nn.compute_average_loss( loss, global_batch_size=args.global_batch_size) # Note: all_reduce_indexed_slices in eager mode is not supported @tf.function def _train_step(inputs, labels): with tf.GradientTape() as tape: logit = model(inputs, training=True) loss = _replica_loss(labels, logit) emb_vars, dense_vars = split_emb_and_dense_variables( model.trainable_variables) # Debug code #print("number of embedding variables: {}".format(len(emb_vars))) #print("number of dense variables : {}".format(len(dense_vars))) emb_grads, dense_grads = tape.gradient(loss, [emb_vars, dense_vars]) # update variables of embedding layer emb_optimizer.apply_gradients(zip(emb_grads, emb_vars), experimental_aggregate_gradients=False) # Mannually all-reduce dense gradients and update variables of dense layers replica_context = tf.distribute.get_replica_context() dense_grads = replica_context.all_reduce("sum", dense_grads, options=comm_options) dense_optimizer.apply_gradients(zip(dense_grads, dense_vars), experimental_aggregate_gradients=False) # manually all-reduce loss, it is ok, because replica_loss has already been used to # update local variables. loss = replica_context.all_reduce(tf.distribute.ReduceOp.SUM, loss, options=comm_options) return loss time_arr = [] for i, (inputs, labels) in enumerate(dataset): if args.stop_at_iter > 0 and i >= args.stop_at_iter: break rng = nvtx.start_range(message="Iteration_" + str(i), color="blue") start_time = time.time() loss = strategy.run(_train_step, args=(inputs, labels)) time_arr.append(time.time() - start_time) nvtx.end_range(rng) print("[INFO]: Iteration: {}, loss={}".format(i, loss)) print("Average iteration time (except 1st iteration): ", np.mean(time_arr[1:]))
def main(args):
    comm_options = None

    if "mirrored" == args.distribute_strategy:
        available_cuda_devices = ",".join(
            [str(gpu_id) for gpu_id in range(args.gpu_num)])
        os.environ["CUDA_VISIBLE_DEVICES"] = available_cuda_devices

        strategy = tf.distribute.MirroredStrategy()
        args.task_id = 0
    elif "multiworker" == args.distribute_strategy:
        args.task_id = int(os.getenv("OMPI_COMM_WORLD_RANK"))
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.task_id)
        args.gpu_num = int(os.getenv("OMPI_COMM_WORLD_SIZE"))

        comm_options = tf.distribute.experimental.CommunicationOptions(
            bytes_per_pack=0,
            timeout_seconds=None,
            implementation=tf.distribute.experimental.CommunicationImplementation.NCCL)

        import json
        port = 12345
        os.environ["TF_CONFIG"] = json.dumps({
            "cluster": {
                "worker": [
                    "localhost" + ":" + str(port + i)
                    for i in range(args.gpu_num)
                ]
            },
            "task": {
                "type": "worker",
                "index": args.task_id
            }
        })

        strategy = tf.distribute.MultiWorkerMirroredStrategy(
            communication_options=comm_options)
    elif "horovod" == args.distribute_strategy:
        import horovod.tensorflow as hvd

        hvd.init()
        args.task_id = hvd.local_rank()
        args.gpu_num = hvd.size()
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.task_id)

        strategy = utils.NullStrategy()
    else:
        raise ValueError(
            "Unsupported distribute_strategy. Must be one of "
            "['mirrored', 'multiworker', 'horovod'], "
            f"but got {args.distribute_strategy}")

    with strategy.scope():
        if args.embedding_layer == "SOK":
            sok.Init(global_batch_size=args.global_batch_size)

        model = DLRM(vocab_size=args.vocab_size_list,
                     num_dense_features=args.num_dense_features,
                     embedding_layer=args.embedding_layer,
                     embedding_vec_size=args.embedding_vec_size,
                     bottom_stack_units=args.bottom_stack,
                     top_stack_units=args.top_stack,
                     TF_MP=args.TF_MP,
                     comm_options=comm_options)

        lr_callable = utils.get_lr_callable(
            global_batch_size=args.global_batch_size,
            decay_exp=args.decay_exp,
            learning_rate=args.learning_rate,
            warmup_steps=args.warmup_steps,
            decay_steps=args.decay_steps,
            decay_start_steps=args.decay_start_steps)

        embedding_optimizer = utils.get_optimizer(args.embedding_optimizer)
        embedding_optimizer.learning_rate = lr_callable
        dense_optimizer = utils.get_optimizer("Adam")

    batch_size = args.global_batch_size if args.distribute_strategy == "mirrored" \
        else args.global_batch_size // args.gpu_num

    if args.distribute_strategy != "mirrored":
        args.train_file_pattern = utils.shard_filenames(
            args.train_file_pattern, args.gpu_num, args.task_id)
        args.test_file_pattern = utils.shard_filenames(
            args.test_file_pattern, args.gpu_num, args.task_id)

    train_dataset = CriteoTsvReader(file_pattern=args.train_file_pattern,
                                    num_dense_features=args.num_dense_features,
                                    vocab_sizes=args.vocab_size_list,
                                    batch_size=batch_size)
    val_dataset = CriteoTsvReader(file_pattern=args.test_file_pattern,
                                  num_dense_features=args.num_dense_features,
                                  vocab_sizes=args.vocab_size_list,
                                  batch_size=batch_size)

    dist_dataset = ("mirrored" == args.distribute_strategy
                    and args.gpu_num > 1)
    train_dataset = utils.get_distribute_dataset(
        train_dataset, strategy, distribute_dataset=dist_dataset)
    val_dataset = utils.get_distribute_dataset(
        val_dataset, strategy, distribute_dataset=dist_dataset)
    val_dataset = iter(val_dataset)

    loss_fn = tf.keras.losses.BinaryCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    def _replica_loss(labels, logits):
        loss = loss_fn(labels, logits)
        return tf.nn.compute_average_loss(
            loss, global_batch_size=args.global_batch_size)

    metrics = [
        tf.keras.metrics.AUC(name="auc"),
        tf.keras.metrics.BinaryAccuracy(name="accuracy"),
        tf.keras.metrics.Mean("prediction_mean"),
        tf.keras.metrics.Mean("label_mean")
    ]
    metrics_threshold = {"auc": 0.8025}

    @tf.function
    def _train_step(features, labels, first_batch=False):
        with tf.GradientTape() as tape:
            logits = model(features, training=True)
            loss = _replica_loss(labels, logits)

        emb_vars, other_vars = utils.split_embedding_variables_from_others(model)
        emb_grads, other_grads = tape.gradient(loss, [emb_vars, other_vars])

        with tf.control_dependencies([logits] + emb_grads):
            utils.apply_gradients(embedding_optimizer, emb_vars, emb_grads,
                                  args.embedding_layer == "SOK",
                                  aggregate_gradients=(not args.TF_MP))

            other_grads = utils.all_reduce(other_grads, combiner="sum",
                                           comm_options=comm_options)
            utils.apply_gradients(dense_optimizer, other_vars, other_grads,
                                  False)

            if first_batch:
                utils.broadcast_variables(other_vars)
                utils.broadcast_variables(dense_optimizer.variables())

                if args.embedding_layer == "TF":
                    utils.broadcast_variables(emb_vars)
                    utils.broadcast_variables(embedding_optimizer.variables())

            total_loss = utils.all_reduce(loss, combiner="sum",
                                          comm_options=comm_options)
        return total_loss

    @tf.function
    def _val_step(features, labels, metrics):
        val_logits = model(features, training=False)
        val_loss = _replica_loss(labels, val_logits)
        val_loss = utils.all_reduce(val_loss, combiner="sum",
                                    comm_options=comm_options)

        labels = tf.identity(labels)
        val_logits = utils.all_gather(val_logits, axis=0,
                                      comm_options=comm_options)
        labels = utils.all_gather(labels, axis=0, comm_options=comm_options)

        return val_logits, labels, val_loss

    stopper = utils.EarlyStopper()

    begin_time = time.time()
    start_time = begin_time

    nvtx_began = False
    steps = 0
    for i, (features, labels) in enumerate(train_dataset):
        steps = i
        if not nvtx_began and i >= args.nvtx_begin_step:
            capture_range = nvtx.start_range(color="red", message="Capture",
                                             domain="Capture")
            nvtx_began = True
        iter_range = nvtx.start_range(message="Iteration_" + str(i),
                                      color="blue")

        if i >= args.train_steps:
            break
        if stopper.should_stop():
            print(stopper.stop_reason)
            break

        total_loss = strategy.run(_train_step, args=(features, labels, i == 0))

        if i % args.validation_interval == 0 and i != 0:
            val_features, val_labels = next(val_dataset)
            val_logits, val_labels, val_loss = \
                strategy.run(_val_step, args=(val_features, val_labels, metrics))

            if hasattr(val_labels, "values"):
                val_labels = val_labels.values[0]
                val_logits = val_logits.values[0]

            update_metrics_states(y_true=val_labels, y_pred=val_logits,
                                  metrics=metrics)
            val_logs = train_loop_end(metrics, total_loss, val_loss,
                                      embedding_optimizer, dense_optimizer,
                                      global_step=i)

            elapsed_time = time.time() - begin_time
            steps_sec = args.validation_interval / elapsed_time
            utils.show_logs(val_logs, strategy, elapsed_time, steps_sec,
                            metrics_threshold, stopper)
            begin_time = time.time()

        nvtx.end_range(iter_range)
        if nvtx_began and capture_range and i >= (args.nvtx_end_step - 1):
            nvtx.end_range(capture_range)
            capture_range = None

    end_time = time.time()
    if args.task_id == 0:
        print(
            f"With {args.distribute_strategy} + {args.embedding_layer} embedding layer, "
            f"on {args.gpu_num} GPUs, and global_batch_size is {args.global_batch_size}, "
            f"it takes {end_time - start_time} seconds to "
            f"finish {steps} steps of training for DLRM.")
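# update_metrics_states and train_loop_end are helpers defined elsewhere in
# the script. As a rough sketch of the first one, matched to the metric list
# above (an assumption about its behavior, not the original code):
def update_metrics_states(y_true, y_pred, metrics):
    # Logits become probabilities before updating the threshold-based
    # metrics; the Mean trackers consume the raw streams.
    y_prob = tf.math.sigmoid(y_pred)
    for metric in metrics:
        if metric.name in ("auc", "accuracy"):
            metric.update_state(y_true, y_prob)
        elif metric.name == "prediction_mean":
            metric.update_state(y_prob)
        elif metric.name == "label_mean":
            metric.update_state(y_true)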
def main(args, task_id):
    print(task_id)

    comm_options = tf.distribute.experimental.CommunicationOptions(
        bytes_per_pack=0,
        timeout_seconds=None,
        implementation=tf.distribute.experimental.CommunicationImplementation.NCCL)

    if args.total_gpu_num == 1:
        strategy = tf.distribute.MirroredStrategy()
    else:
        port = 12345
        os.environ["TF_CONFIG"] = json.dumps({
            "cluster": {
                "worker": [
                    "localhost" + ":" + str(port + i)
                    for i in range(args.worker_num)
                ]
            },
            "task": {
                "type": "worker",
                "index": task_id
            }
        })
        strategy = tf.distribute.MultiWorkerMirroredStrategy(
            communication_options=comm_options)

    if args.data_splited:
        filename = args.data_filename + str(task_id) + ".file"
    else:
        filename = args.data_filename

    replica_batch_size = args.global_batch_size // (args.worker_num * 1)

    dataset = utility.TFDataset(filename=filename,
                                batchsize=replica_batch_size,
                                as_sparse_tensor=False,
                                repeat=1)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    with strategy.scope():
        sok.Init(global_batch_size=args.global_batch_size)

        model = SOKDenseDemo(
            max_vocabulary_size_per_gpu=args.max_vocabulary_size_per_gpu,
            embedding_vec_size=args.embedding_vec_size,
            slot_num=args.slot_num,
            nnz_per_slot=args.nnz_per_slot,
            num_dense_layers=args.num_dense_layers)

        embedding_optimizer = utility.get_embedding_optimizer(
            args.optimizer)(learning_rate=0.1)
        dense_optimizer = utility.get_dense_optimizer(
            args.optimizer)(learning_rate=0.1)

    loss_fn = tf.keras.losses.BinaryCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    def _replica_loss(labels, logits):
        loss = loss_fn(labels, logits)
        return tf.nn.compute_average_loss(
            loss, global_batch_size=args.global_batch_size)

    @tf.function
    def _train_step(inputs, labels):
        with tf.GradientTape() as tape:
            logit = model(inputs, training=True)
            loss = _replica_loss(labels, logit)
        emb_variable, other_variable = sok.split_embedding_variable_from_others(
            model.trainable_variables)
        grads, emb_grads = tape.gradient(loss, [other_variable, emb_variable])

        if "plugin" not in args.optimizer:
            with sok.OptimizerScope(emb_variable):
                embedding_optimizer.apply_gradients(
                    zip(emb_grads, emb_variable),
                    experimental_aggregate_gradients=False)
        else:
            embedding_optimizer.apply_gradients(
                zip(emb_grads, emb_variable),
                experimental_aggregate_gradients=False)

        # Manually all-reduce the dense gradients.
        replica_context = tf.distribute.get_replica_context()
        grads = replica_context.all_reduce("sum", grads, options=comm_options)
        dense_optimizer.apply_gradients(zip(grads, other_variable),
                                        experimental_aggregate_gradients=False)

        # Manually all-reduce the loss. This is safe because the replica loss
        # has already been used to update the local variables.
        loss = replica_context.all_reduce(tf.distribute.ReduceOp.SUM,
                                          loss, options=comm_options)
        return loss

    time_arr = []
    for i, (inputs, labels) in enumerate(dataset):
        if args.stop_at_iter > 0 and i >= args.stop_at_iter:
            break

        rng = nvtx.start_range(message="Iteration_" + str(i), color="blue")
        start_time = time.time()
        total_loss = strategy.run(_train_step, args=(inputs, labels))
        time_arr.append(time.time() - start_time)
        nvtx.end_range(rng)

        if task_id == '0':
            print("[INFO]: Iteration: {}, loss={}".format(i, total_loss))

    if task_id == '0':
        print("Average iteration time (except 1st iteration): ",
              np.mean(time_arr[1:]))
def main(args):
    # Initialize Horovod.
    hvd.init()

    gpus = tf.config.list_physical_devices("GPU")
    tf.config.set_visible_devices(gpus[hvd.local_rank()], "GPU")

    # Generate the local filename.
    # Assume the dataset has been split in advance.
    local_file = args.data_filename_prefix + str(hvd.local_rank()) + ".file"

    # Generate the local batch size.
    assert (args.global_batch_size % hvd.size() == 0)
    local_batch_size = args.global_batch_size // hvd.size()

    dataset = utility.TFDataset(filename=local_file,
                                batchsize=local_batch_size,
                                as_sparse_tensor=False,
                                repeat=1)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    # Because there is no TensorFlow distribute strategy, sok.Init() will call
    # Horovod to broadcast the NCCL id and random seed, so it must be called
    # after hvd.init().
    sok.Init(global_batch_size=args.global_batch_size)

    model = SOKDenseDemo(
        max_vocabulary_size_per_gpu=args.max_vocabulary_size_per_gpu,
        embedding_vec_size=args.embedding_vec_size,
        slot_num=args.slot_num,
        nnz_per_slot=args.nnz_per_slot,
        num_dense_layers=args.num_dense_layers)

    embedding_optimizer = utility.get_embedding_optimizer(
        args.optimizer)(learning_rate=0.1)
    dense_optimizer = utility.get_dense_optimizer(
        args.optimizer)(learning_rate=0.1)

    loss_fn = tf.keras.losses.BinaryCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    def _replica_loss(labels, logits):
        loss = loss_fn(labels, logits)
        return tf.nn.compute_average_loss(
            loss, global_batch_size=args.global_batch_size)

    @tf.function
    def _train_step(inputs, labels, first_batch):
        with tf.GradientTape() as tape, tf.GradientTape() as emb_tape:
            logit = model(inputs, training=True)
            replica_loss = _replica_loss(labels, logit)

        # Horovod: wrap tf.GradientTape with Horovod's DistributedGradientTape.
        tape = hvd.DistributedGradientTape(tape)

        # There is no need to wrap emb_tape, because the communication is done
        # by SOK itself.
        # emb_tape = hvd.DistributedGradientTape(emb_tape)

        emb_variable, other_variable = sok.split_embedding_variable_from_others(
            model.trainable_variables)

        # type(emb_tape) here is tf.GradientTape
        # type(tape) here is hvd.DistributedGradientTape
        emb_grads = emb_tape.gradient(replica_loss, emb_variable)
        grads = tape.gradient(replica_loss, other_variable)

        if "plugin" not in args.optimizer:
            with sok.OptimizerScope(emb_variable):
                embedding_optimizer.apply_gradients(
                    zip(emb_grads, emb_variable),
                    experimental_aggregate_gradients=False)
        else:
            embedding_optimizer.apply_gradients(
                zip(emb_grads, emb_variable),
                experimental_aggregate_gradients=False)

        dense_optimizer.apply_gradients(zip(grads, other_variable))

        # Note: the broadcast should be done after the first gradient step to
        # ensure the optimizer has been initialized. There is no need to
        # broadcast emb_variable or embedding_optimizer, because SOK's
        # parallel mode is model parallel and the communication is done by
        # SOK itself.
        if first_batch:
            hvd.broadcast_variables(other_variable, root_rank=0)
            hvd.broadcast_variables(dense_optimizer.variables(), root_rank=0)

        return replica_loss

    for i, (inputs, labels) in enumerate(dataset):
        if args.stop_at_iter > 0 and i >= args.stop_at_iter:
            break

        rng = nvtx.start_range(message="Iteration_" + str(i), color="blue")
        total_loss = _train_step(inputs, labels, i == 0)
        nvtx.end_range(rng)

        print("[INFO]: Iteration: {}, loss={}".format(i, total_loss))