def step_fn(features):
  """Runs one forward/backward pass and applies the update on one replica.

  The loss is weighted by the valid-point mask and normalised by
  `batch_size * strategy.num_replicas_in_sync`. The gradients are densified
  via `neumf_model.sparse_to_dense_grads` and handed to the optimizer.

  Args:
    features: mapping of input tensors. It is fed to `keras_model` and must
      also contain `rconst.TRAIN_LABEL_KEY` and `rconst.VALID_POINT_MASK`.

  Returns:
    The normalised loss computed on this replica.
  """
  # Normalise over the batch seen by all replicas combined.
  # NOTE(review): this assumes `batch_size` is the per-replica batch; confirm
  # against the caller.
  global_batch_size = batch_size * strategy.num_replicas_in_sync
  with tf.GradientTape() as tape:
    targets = features[rconst.TRAIN_LABEL_KEY]
    point_weights = features[rconst.VALID_POINT_MASK]
    predictions = keras_model(features)
    replica_loss = loss_object(
        targets, predictions, sample_weight=point_weights)
    replica_loss = replica_loss * (1.0 / global_batch_size)

  # Read the variables only after the model has been called, so that any
  # variables created on the first call are included.
  variables = keras_model.trainable_variables
  raw_grads = tape.gradient(replica_loss, variables)
  # Densify the gradients; the authors note this improves GPU performance
  # for NCF.
  grads_and_vars = neumf_model.sparse_to_dense_grads(
      list(zip(raw_grads, variables)))
  optimizer.apply_gradients(grads_and_vars)
  return replica_loss
def step_fn(features):
  """Computes the loss and applies gradients for one replica.

  The loss is weighted by the valid-point mask and normalised by
  `params["batch_size"]`. When `FLAGS.dtype == "fp16"`, the loss is scaled
  with `optimizer.get_scaled_loss` for differentiation only. The resulting
  gradients are unscaled with `optimizer.get_unscaled_gradients` before being
  densified and applied.

  Args:
    features: mapping of input tensors. It is fed to `keras_model` and must
      also contain `rconst.TRAIN_LABEL_KEY` and `rconst.VALID_POINT_MASK`.

  Returns:
    The normalised, unscaled loss. Previously the fp16 path returned the
    loss-scaled value, which inflated the reported loss by the loss-scale
    factor.
  """
  use_loss_scaling = FLAGS.dtype == "fp16"
  with tf.GradientTape() as tape:
    softmax_logits = keras_model(features)
    labels = features[rconst.TRAIN_LABEL_KEY]
    loss = loss_object(
        labels,
        softmax_logits,
        sample_weight=features[rconst.VALID_POINT_MASK])
    # NOTE(review): presumably params["batch_size"] is the global batch size;
    # confirm against the caller.
    loss *= (1.0 / params["batch_size"])
    # Keep `loss` unscaled for reporting. Only the tensor we differentiate
    # carries the fp16 loss scale. It must be computed inside the tape so
    # gradients flow through it. Assumes `optimizer` is a loss-scale
    # optimizer whenever fp16 is selected.
    if use_loss_scaling:
      loss_for_grads = optimizer.get_scaled_loss(loss)
    else:
      loss_for_grads = loss

  grads = tape.gradient(loss_for_grads, keras_model.trainable_variables)
  if use_loss_scaling:
    # Undo the loss scale so the optimizer sees true-magnitude gradients.
    grads = optimizer.get_unscaled_gradients(grads)
  # Converting gradients to dense form helps in perf on GPU for NCF
  grads = neumf_model.sparse_to_dense_grads(
      list(zip(grads, keras_model.trainable_variables)))
  optimizer.apply_gradients(grads)
  return loss