import tensorflow as tf
import smdistributed.dataparallel.tensorflow as dist


def training_step(images, labels, batch):
    """
    Runs one training step, with considerations for the distributed nature of the job.

    :param images: images, sorted
    :param labels: labels for the images, in the same order
    :param batch: batch number; variables must be broadcast on the first batch
    :return: loss for this step, averaged across all workers
    """
    with tf.GradientTape() as tape:
        # Record operations for automatic differentiation (learning).
        probs = model(images, training=True)
        loss_value = loss(labels, probs)

    # Wrap the tape so that gradients are AllReduced across workers
    # before they are applied.
    tape = dist.DistributedGradientTape(tape)
    grads = tape.gradient(loss_value, model.trainable_variables)
    # Apply the (already averaged) gradients on this worker.
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    if batch == 0:
        # Workers share learning by broadcasting variables: one set of
        # variables across all workers, seeded from rank 0.
        dist.broadcast_variables(model.variables, root_rank=0)
        dist.broadcast_variables(optimizer.variables(), root_rank=0)

    # Reduce the loss values from all workers to a single value.
    loss_value = dist.oob_allreduce(loss_value)  # Average the loss across workers
    return loss_value
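# A minimal driver sketch for the training_step above, following the usual
# smdistributed.dataparallel setup pattern. `model`, `loss`, and `dataset`
# are placeholders for objects assumed to be defined elsewhere in the script.
import tensorflow as tf
import smdistributed.dataparallel.tensorflow as dist

dist.init()

# Pin each worker process to a single GPU, chosen by its local rank.
gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    tf.config.experimental.set_visible_devices(gpus[dist.local_rank()], "GPU")

# Scale the learning rate with the number of workers, since the effective
# global batch size grows with dist.size().
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001 * dist.size())

for batch_number, (images, labels) in enumerate(dataset):
    loss_value = training_step(images, labels, batch_number)
    if batch_number % 50 == 0 and dist.rank() == 0:
        print(f"step {batch_number}\tloss {loss_value:.6f}")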
import tensorflow as tf
from tqdm import tqdm

# SDP_ENABLED is assumed to be a module-level flag; when it is set, the
# SageMaker data parallel library is imported elsewhere as
# `import smdistributed.dataparallel.tensorflow as sdp`.


def fit(model, loss, opt, train_dataset, epochs, train_batch_size, max_steps=None):
    pbar = tqdm(train_dataset)
    for i, batch in enumerate(pbar):
        with tf.GradientTape() as tape:
            inputs, targets = batch
            outputs = model(inputs)
            loss_value = loss(targets, outputs.logits)

        if SDP_ENABLED:
            # Wrap the tape so gradients are AllReduced across workers.
            tape = sdp.DistributedGradientTape(tape, sparse_as_dense=True)

        grads = tape.gradient(loss_value, model.trainable_variables)
        opt.apply_gradients(zip(grads, model.trainable_variables))

        pbar.set_description(f"Loss: {loss_value.numpy():.4f}")

        if SDP_ENABLED and i == 0:
            # Broadcast the initial variables from rank 0 after the first step.
            sdp.broadcast_variables(model.variables, root_rank=0)
            sdp.broadcast_variables(opt.variables(), root_rank=0)

        if max_steps and i >= max_steps:
            break

    train_results = {"loss": loss_value.numpy()}
    return train_results
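# A usage sketch for fit() above. `train_dataset` (a tf.data.Dataset of
# (inputs, targets) pairs) and the Hugging Face `model`, `loss`, and `opt`
# are assumed to be defined elsewhere; the sharding call is standard tf.data.
if SDP_ENABLED:
    # Give each worker a distinct shard so no two workers see the same data.
    train_dataset = train_dataset.shard(sdp.size(), sdp.rank())
train_dataset = train_dataset.batch(train_batch_size, drop_remainder=True)

train_results = fit(model, loss, opt, train_dataset, epochs, train_batch_size)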
import tensorflow as tf
import smdistributed.dataparallel.tensorflow as smdp


def training_step(mnist_model, loss, opt, images, labels, batch):
    with tf.GradientTape() as tape:
        probs = mnist_model(images, training=True)
        loss_value = loss(labels, probs)

    ########################################################
    ####### 4. SageMaker Distributed Data Parallel  ########
    #######    - Optimize AllReduce operation        #######
    ########################################################
    # SMDataParallel: Wrap tf.GradientTape with SMDataParallel's DistributedGradientTape
    tape = smdp.DistributedGradientTape(tape)
    ########################################################

    grads = tape.gradient(loss_value, mnist_model.trainable_variables)
    opt.apply_gradients(zip(grads, mnist_model.trainable_variables))

    ########################################################
    ####### 5. SageMaker Distributed Data Parallel  ########
    #######    - Broadcast the initial model variables #####
    #######      from rank 0 to ranks 1 ~ n          #######
    ########################################################
    if batch == 0:
        # SMDataParallel: Broadcast model and optimizer variables
        smdp.broadcast_variables(mnist_model.variables, root_rank=0)
        smdp.broadcast_variables(opt.variables(), root_rank=0)
    ########################################################

    return loss_value
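# A sketch of the objects passed into the training_step above, built from
# standard Keras MNIST components; the exact architecture, loss, and learning
# rate are assumptions here, not the original tutorial's values. Assumes
# smdp.init() has already been called.
mnist_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, [3, 3], activation="relu", input_shape=(28, 28, 1)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation="softmax"),
])
loss = tf.keras.losses.SparseCategoricalCrossentropy()
# Scale the learning rate by the worker count to match the larger
# effective batch size under data parallelism.
opt = tf.keras.optimizers.Adam(learning_rate=0.001 * smdp.size())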
import tensorflow as tf
import smdistributed.dataparallel.tensorflow as dist


def training_step(images, labels, first_batch):
    with tf.GradientTape() as tape:
        probs = mnist_model(images, training=True)
        loss_value = loss(labels, probs)

    # Wrap the tape so gradients are AllReduced across workers.
    tape = dist.DistributedGradientTape(tape)
    grads = tape.gradient(loss_value, mnist_model.trainable_variables)
    opt.apply_gradients(zip(grads, mnist_model.trainable_variables))

    if first_batch:
        # Broadcast the initial model and optimizer variables from rank 0.
        dist.broadcast_variables(mnist_model.variables, root_rank=0)
        dist.broadcast_variables(opt.variables(), root_rank=0)

    loss_value = dist.oob_allreduce(loss_value)  # Average the loss across workers
    return loss_value
import tensorflow as tf
import smdistributed.dataparallel.tensorflow as smdataparallel


def training_step(images, labels, first_batch):
    with tf.GradientTape() as tape:
        probs = mnist_model(images, training=True)
        loss_value = loss(labels, probs)

    # Create a new DistributedGradientTape, which uses TensorFlow's GradientTape
    # under the hood, using an AllReduce to combine gradient values before
    # applying gradients to model weights.
    tape = smdataparallel.DistributedGradientTape(tape)
    grads = tape.gradient(loss_value, mnist_model.trainable_variables)
    opt.apply_gradients(zip(grads, mnist_model.trainable_variables))

    # Broadcast model and optimizer variables after the first forward pass
    # to keep all workers in sync.
    if first_batch:
        smdataparallel.broadcast_variables(mnist_model.variables, root_rank=0)
        smdataparallel.broadcast_variables(opt.variables(), root_rank=0)

    return loss_value
import tensorflow as tf
import smdistributed.dataparallel.tensorflow as dist


def training_step(images, labels, first_batch):
    with tf.GradientTape() as tape:
        probs = mnist_model(images, training=True)
        loss_value = loss(labels, probs)

    # SMDataParallel: Wrap tf.GradientTape with SMDataParallel's DistributedGradientTape
    tape = dist.DistributedGradientTape(tape)
    grads = tape.gradient(loss_value, mnist_model.trainable_variables)
    opt.apply_gradients(zip(grads, mnist_model.trainable_variables))

    if first_batch:
        # SMDataParallel: Broadcast model and optimizer variables
        dist.broadcast_variables(mnist_model.variables, root_rank=0)
        dist.broadcast_variables(opt.variables(), root_rank=0)

    # SMDataParallel: all_reduce call
    loss_value = dist.oob_allreduce(loss_value)  # Average the loss across workers
    return loss_value
import tensorflow as tf
import smdistributed.dataparallel.tensorflow as smdp


def training_step(images, labels, first_batch):
    with tf.GradientTape() as tape:
        train_pred = model(images, training=True)
        loss_value = loss(labels, train_pred)

    # Change: Wrap tf.GradientTape with SMDataParallel's DistributedGradientTape
    tape = smdp.DistributedGradientTape(tape)
    grads = tape.gradient(loss_value, model.trainable_variables)
    opt.apply_gradients(zip(grads, model.trainable_variables))

    if first_batch:
        # Change: Broadcast model and optimizer variables
        smdp.broadcast_variables(model.variables, root_rank=0)
        smdp.broadcast_variables(opt.variables(), root_rank=0)

    # Change: all_reduce call
    train_loss_value = smdp.oob_allreduce(loss_value)  # Average the loss across workers

    # Update the running metrics with the averaged loss and this worker's predictions.
    train_loss(train_loss_value)
    train_accuracy(labels, train_pred)
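# The train_loss and train_accuracy objects updated above are assumed to be
# Keras metrics defined at module scope, along the lines of this sketch:
train_loss = tf.keras.metrics.Mean(name="train_loss")
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="train_accuracy")

# Sketch of an epoch loop around training_step; `train_ds` stands in for the
# already sharded and batched per-worker dataset, and `epochs` for a count
# defined elsewhere.
for epoch in range(epochs):
    train_loss.reset_states()
    train_accuracy.reset_states()
    for step, (images, labels) in enumerate(train_ds):
        training_step(images, labels, epoch == 0 and step == 0)
    if smdp.rank() == 0:
        print(f"Epoch {epoch + 1}: loss={train_loss.result():.4f}, "
              f"accuracy={train_accuracy.result():.4f}")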