from tensorflow import keras
from tensorflow.python.ops import nn


def stage2(lr, partial, labels):
  logits = keras.layers.Dense(10)(partial)
  per_example_loss = keras.losses.sparse_categorical_crossentropy(
      y_true=labels, y_pred=logits, from_logits=True)
  # In a custom training loop, the optimiser does an allreduce *sum*, not
  # an average, of the gradients across the distributed workers. Therefore
  # we want to divide the loss here by the *global* batch size, which is
  # done by the `tf.nn.compute_average_loss()` function.
  loss = nn.compute_average_loss(per_example_loss)
  return lr, loss
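# To make the division concrete, here is a quick numeric check of
# `tf.nn.compute_average_loss` (illustrative values, run outside any
# distribution strategy, where the global batch size defaults to the local
# batch size):
import tensorflow as tf

per_example_loss = tf.constant([1.0, 2.0, 3.0, 6.0])

# Sum of per-example losses divided by the batch size.
loss = tf.nn.compute_average_loss(per_example_loss)
print(loss.numpy())  # 3.0 == (1.0 + 2.0 + 3.0 + 6.0) / 4

# With an explicit global batch size, each replica divides by the global
# count, so an allreduce *sum* of the gradients reconstructs the true mean.
loss = tf.nn.compute_average_loss(per_example_loss, global_batch_size=8)
print(loss.numpy())  # 1.5 == 12.0 / 8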
from tensorflow import keras
from tensorflow.python.eager import backprop
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn


def step_fn(inputs):
  """The computation to run on each worker."""
  features, labels = inputs
  with backprop.GradientTape() as tape:
    pred = model(features, training=True)
    # Per-example loss, averaged over the *global* batch so the
    # cross-replica gradient reduction yields the true mean gradient.
    loss = keras.losses.binary_crossentropy(labels, pred)
    loss = nn.compute_average_loss(loss)
  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))

  # Threshold the sigmoid output at 0.5 to get hard class predictions.
  actual_pred = math_ops.cast(math_ops.greater(pred, 0.5), dtypes.int64)
  accuracy.update_state(labels, actual_pred)
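# For context, a `step_fn` of this shape is typically dispatched with
# `tf.distribute.Strategy.run`. The driver below is a minimal sketch using
# the public API; the choice of `MirroredStrategy`, the model, optimizer,
# metric, and dataset are assumptions, not part of the snippet above.
import tensorflow as tf
from tensorflow import keras

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  # Variables (model, optimizer, metric state) must be created in the scope.
  model = keras.Sequential([keras.layers.Dense(1, activation="sigmoid")])
  optimizer = keras.optimizers.SGD(learning_rate=0.1)
  accuracy = keras.metrics.Accuracy()

features = tf.random.normal([64, 4])
labels = tf.cast(tf.random.uniform([64, 1], maxval=2, dtype=tf.int32),
                 tf.float32)
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(16)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def train_step(inputs):
  # Each replica runs `step_fn` on its shard of the global batch.
  strategy.run(step_fn, args=(inputs,))

for batch in dist_dataset:
  train_step(batch)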
from tensorflow import keras
from tensorflow.python.eager import backprop
from tensorflow.python.framework import dtypes
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn


def replica_fn(iterator):
  """Runs one training step on a single replica."""
  batch_data, labels = next(iterator)
  with backprop.GradientTape() as tape:
    pred = model(batch_data, training=True)
    # Keep the loss unreduced (`ReductionV2.NONE`) so that
    # `compute_average_loss` can divide by the global batch size itself.
    loss = nn.compute_average_loss(
        keras.losses.BinaryCrossentropy(
            reduction=losses_utils.ReductionV2.NONE)(labels, pred))
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  # Threshold the sigmoid output at 0.5 to get hard class predictions.
  actual_pred = math_ops.cast(math_ops.greater(pred, 0.5), dtypes.int64)
  accuracy.update_state(labels, actual_pred)
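# Because `replica_fn` pulls its own batch from the iterator, it has the
# shape of a per-worker function for
# `tf.distribute.experimental.coordinator.ClusterCoordinator`. A possible
# driver is sketched below, assuming a parameter-server cluster with one
# replica per worker; `cluster_resolver` and `dataset_fn` are placeholders,
# and `model`, `optimizer`, and `accuracy` would be created under
# `strategy.scope()`.
import tensorflow as tf

strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy)

@tf.function
def worker_fn(iterator):
  # One replica per worker is assumed here; multi-device workers would
  # instead split the batch and wrap the body in `strategy.run`.
  replica_fn(iterator)

per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)
per_worker_iterator = iter(per_worker_dataset)
for _ in range(100):  # placeholder step count
  coordinator.schedule(worker_fn, args=(per_worker_iterator,))
coordinator.join()  # block until all scheduled steps have finished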
from tensorflow import keras
from tensorflow.python.ops import nn


def compute_loss(labels, logits, reg_losses):
  """Combines the prediction loss with scaled regularization losses."""
  pred_loss = keras.losses.mean_squared_error(labels, logits)
  # Divide by the global batch size (`_BATCH_SIZE` is defined elsewhere in
  # the module) rather than the per-replica batch size.
  scaled_loss = nn.compute_average_loss(
      pred_loss, global_batch_size=_BATCH_SIZE)
  # Regularization terms are summed and divided by the number of replicas.
  l2_loss = nn.scale_regularization_loss(reg_losses)
  return scaled_loss + l2_loss
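# `reg_losses` would typically be the model's collected regularization
# penalties (`model.losses`). A minimal sketch of a per-replica training
# step built on `compute_loss`; `model` and `optimizer` are assumed to be
# created under the strategy's scope, as in the snippets above.
import tensorflow as tf

def train_step(features, labels):
  with tf.GradientTape() as tape:
    logits = model(features, training=True)
    # `tf.nn.scale_regularization_loss` divides the summed regularization
    # terms by the number of replicas in sync, so adding them here keeps
    # the total loss consistent across any replica count.
    loss = compute_loss(labels, logits, model.losses)
  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss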