def training_step(images, labels, batch):
    """
    Runs one training step, with considerations for the distributed nature of the job.

    :param images: batch of input images
    :param labels: labels for the images, in the same order
    :param batch: batch index; variables must be broadcast on the first batch
    :return: loss for this step, averaged across all workers
    """
    with tf.GradientTape() as tape:
        # Record the forward pass so gradients can be computed for this worker's batch
        probs = model(images, training=True)
        loss_value = loss(labels, probs)

    # Wrap the tape with the distributed tape so gradients are averaged across workers
    tape = dist.DistributedGradientTape(tape)

    grads = tape.gradient(loss_value, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))  # apply the update on this worker

    if batch == 0:
        # Broadcast variables from rank 0 so every worker starts from the same state
        dist.broadcast_variables(model.variables, root_rank=0)
        dist.broadcast_variables(optimizer.variables(), root_rank=0)

    # Reduce the per-worker loss values into a single value
    loss_value = dist.oob_allreduce(loss_value)  # Average the loss across workers
    return loss_value
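
# A minimal setup sketch, assuming `dist` above is smdistributed.dataparallel.tensorflow;
# the optimizer and learning-rate value here are placeholders, shown only to illustrate
# the init, per-process GPU pinning, and learning-rate scaling that the step function
# relies on.
import tensorflow as tf
import smdistributed.dataparallel.tensorflow as dist

dist.init()  # one process per GPU; sets up rank, local_rank, and size

gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    # Pin each process to its own local GPU
    tf.config.experimental.set_visible_devices(gpus[dist.local_rank()], 'GPU')

# Scale the base learning rate by the worker count, since the effective
# global batch size grows with dist.size()
optimizer = tf.optimizers.Adam(learning_rate=0.001 * dist.size())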
def allreduce(model, optimizer, gradient_accumulator, loss, mlm_loss,
              mlm_acc, sop_loss, sop_acc):
    scaled_grads = gradient_accumulator.gradients
    grads = optimizer.get_unscaled_gradients(scaled_grads)
    # This, which is equivalent to sparse_as_dense=True, gives a mild 2% speedup from 0.62 it/s to 0.63 it/s
    # on BERT-large multinode.
    grads = [
        tf.convert_to_tensor(grad)
        if grad is not None and isinstance(grad, tf.IndexedSlices) else grad
        for grad in grads
    ]

    # TODO: Does placing this clip before or after allreduce affect accuracy?
    # Placing before has a regularization effect, no single example can contribute as much.
    # Placing before also gives a 20% speedup when training BERT-large, probably because the
    # gradient operations can be fused by XLA.
    (grads, grad_norm) = tf.clip_by_global_norm(grads, clip_norm=max_grad_norm)

    weight_norm = tf.math.sqrt(
        tf.math.reduce_sum(
            [tf.norm(var, ord=2) ** 2 for var in model.trainable_variables]))

    grads = [
        smddp.allreduce(grad,
                        param_index=idx,
                        num_params=len(grads),
                        compression=smddp.Compression.fp16)
        if grad is not None else None for idx, grad in enumerate(grads)
    ]

    optimizer.apply_gradients([
        (tf.cast(grad, var.dtype), var)
        for (grad, var) in zip(grads, model.trainable_variables)
        if grad is not None
    ])

    # Clear the gradient accumulator
    gradient_accumulator.reset()

    loss = smddp.oob_allreduce(loss)
    mlm_loss = smddp.oob_allreduce(mlm_loss)
    mlm_acc = smddp.oob_allreduce(mlm_acc)
    sop_loss = smddp.oob_allreduce(sop_loss)
    sop_acc = smddp.oob_allreduce(sop_acc)

    return loss, mlm_loss, mlm_acc, sop_loss, sop_acc, grad_norm, weight_norm
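
# A minimal sketch of the kind of gradient accumulator assumed by allreduce() above.
# The class name and methods are hypothetical stand-ins, shown only to make the
# `gradient_accumulator.gradients` / `.reset()` contract concrete: each micro-batch's
# (scaled) gradients are summed locally, and allreduce() is called once per global step.
import tensorflow as tf


class GradientAccumulator:
    def __init__(self, variables):
        # One non-trainable accumulator variable per model variable
        self._accum = [
            tf.Variable(tf.zeros_like(v), trainable=False) for v in variables
        ]

    def add(self, grads):
        # Sum this micro-batch's gradients into the accumulators
        for acc, grad in zip(self._accum, grads):
            if grad is not None:
                acc.assign_add(tf.convert_to_tensor(grad))

    @property
    def gradients(self):
        return [acc.value() for acc in self._accum]

    def reset(self):
        for acc in self._accum:
            acc.assign(tf.zeros_like(acc))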
def training_step(images, labels, first_batch):
    with tf.GradientTape() as tape:
        probs = model(images, training=True)
        loss_value = loss_fn(labels, probs)
        acc_value = acc(labels, probs)

    # SMDataParallel: Wrap tf.GradientTape with SMDataParallel's DistributedGradientTape
    tape = dist.DistributedGradientTape(tape)

    grads = tape.gradient(loss_value, model.trainable_variables)
    opt.apply_gradients(zip(grads, model.trainable_variables))

    if first_batch:
        # SMDataParallel: Broadcast model and optimizer variables
        dist.broadcast_variables(model.variables, root_rank=0)
        dist.broadcast_variables(opt.variables(), root_rank=0)

    # SMDataParallel: all_reduce call
    loss_value = dist.oob_allreduce(loss_value)  # Average the loss across workers
    acc_value = dist.oob_allreduce(acc_value)

    return loss_value, acc_value
def training_step(images, labels, first_batch):
    with tf.GradientTape() as tape:
        probs = mnist_model(images, training=True)
        loss_value = loss(labels, probs)

    tape = dist.DistributedGradientTape(tape)

    grads = tape.gradient(loss_value, mnist_model.trainable_variables)
    opt.apply_gradients(zip(grads, mnist_model.trainable_variables))

    if first_batch:
        dist.broadcast_variables(mnist_model.variables, root_rank=0)
        dist.broadcast_variables(opt.variables(), root_rank=0)

    loss_value = dist.oob_allreduce(loss_value)  # Average the loss across workers

    return loss_value
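
# A sketch of the data, model, and loop that could drive the MNIST step function above.
# The model architecture, learning rate, step count, and checkpoint path are assumptions,
# shown only to illustrate the common pattern: per-rank data, first_batch=True exactly
# once, and logging/checkpointing from rank 0 only.
import tensorflow as tf
import smdistributed.dataparallel.tensorflow as dist

dist.init()  # assumes each process has also been pinned to its own GPU

(mnist_images, mnist_labels), _ = tf.keras.datasets.mnist.load_data(
    path='mnist-%d.npz' % dist.rank())  # separate cache file per rank
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32),
     tf.cast(mnist_labels, tf.int64)))
dataset = dataset.repeat().shuffle(10000).batch(128)

mnist_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation='softmax'),
])
loss = tf.losses.SparseCategoricalCrossentropy()
opt = tf.optimizers.Adam(0.001 * dist.size())  # scale the learning rate by worker count

# Each worker runs its share of the total number of steps
for batch, (images, labels) in enumerate(dataset.take(10000 // dist.size())):
    loss_value = training_step(images, labels, batch == 0)
    if batch % 50 == 0 and dist.rank() == 0:
        print('Step #%d\tLoss: %.6f' % (batch, loss_value))

if dist.rank() == 0:
    mnist_model.save_weights('/tmp/mnist_model.ckpt')  # placeholder path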
def training_step(images, labels, first_batch):
    with tf.GradientTape() as tape:
        train_pred = model(images, training=True)
        loss_value = loss(labels, train_pred)

    # Change: Wrap tf.GradientTape with SMDataParallel's DistributedGradientTape
    tape = smdp.DistributedGradientTape(tape)

    grads = tape.gradient(loss_value, model.trainable_variables)
    opt.apply_gradients(zip(grads, model.trainable_variables))

    if first_batch:
        # Change: Broadcast model and optimizer variables
        smdp.broadcast_variables(model.variables, root_rank=0)
        smdp.broadcast_variables(opt.variables(), root_rank=0)

    # Change: all_reduce call
    train_loss_value = smdp.oob_allreduce(loss_value)  # Average the loss across workers

    train_loss(train_loss_value)
    train_accuracy(labels, train_pred)
    return
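
# A minimal sketch of the Keras metric objects updated at the end of the step function
# above. The names match the snippet, but the exact metric choices are assumptions;
# metrics accumulate locally, so results are typically read and logged on rank 0 only.
import tensorflow as tf

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')


def log_epoch(epoch, rank):
    # Read and reset the metrics once per epoch, printing only on the lead worker
    if rank == 0:
        print('Epoch %d: loss %.4f, accuracy %.4f' %
              (epoch, train_loss.result(), train_accuracy.result()))
    train_loss.reset_states()
    train_accuracy.reset_states()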