# Context (assumed): this is a method on a Lingvo task; at module level,
# `from lingvo import compat as tf` and `from lingvo.core import py_utils`
# are in scope.
def _ComputeLossesAndGradients(self, metrics, vmap):
  p = self.params
  vmap = self.GetTrainableVariables(vmap)

  # Get TPU embedding activations to compute gradients for.
  tpu_embedding_activations = py_utils.NestedMap()
  tpu_embedding_graph_collection = py_utils.GetTpuEmbeddingGraphCollection()
  if tpu_embedding_graph_collection:
    tpu_embedding_collection = tpu_embedding_graph_collection[0]
    task_call_scope = py_utils.GetTaskCallScope()
    tpu_embedding_activations = py_utils.NestedMap(
        tpu_embedding_collection.GetActivations(task_call_scope) or {})
    # It's possible that task_call_scope is None and its mode is not set in
    # tpu_embedding_collection (e.g. in unit tests), but if the activations
    # are not empty, the mode must have been set.
    if tpu_embedding_activations and (
        tpu_embedding_collection.ShouldStopGradient(task_call_scope)):
      tpu_embedding_activations = py_utils.NestedMap()

  # Log every trainable variable that will receive gradients.
  for v in vmap.Flatten():
    tf.logging.info('%s: bprop variable: %s', p.name, v.name)

  def LossAndGradients(metric_name):
    """Returns (metric, var_grads) computed from metrics[metric_name]."""
    metric = metrics.get(metric_name, None)
    if metric is None:
      raise ValueError('Loss %s not found in metrics %s' %
                       (metric_name, list(metrics.keys())))
    # TODO(b/154785713): pass (loss, loss_weight) to ComputeGradients().
    loss = metric[0]
    return metric, self.optimizer.ComputeGradients(
        loss,
        vmap,
        p.grad_aggregation_method,
        p.colocate_gradients_with_ops,
        p.gate_gradients,
        compute_gradients_fn=self._CustomComputeGradientsFn(),
        skip_zero_gradients=p.skip_zero_gradients,
        skip_none_gradients=False,
        tpu_embedding_activations=tpu_embedding_activations)

  loss_name = p.loss_name or p.name
  losses = []
  eval_metrics = {}
  if isinstance(loss_name, (list, tuple)):
    # Multiple losses: compute per-loss gradients, then merge them into a
    # single gradient per variable via the gradient combiner.
    assert not tpu_embedding_activations, (
        'TPU embedding does not currently support multiple losses.')
    losses_and_grads = {}
    variables = None
    for metric_name in loss_name:
      loss_metric, var_grads = LossAndGradients(metric_name)
      losses_and_grads[metric_name] = py_utils.NestedMap(
          loss_metric=loss_metric,
          grads=tf.nest.map_structure(lambda vg: vg.grad, var_grads))
      current_vars = tf.nest.map_structure(lambda vg: vg.var, var_grads)
      # All losses must produce gradients for the same variable structure.
      if variables is None:
        variables = current_vars
      else:
        tf.nest.assert_same_structure(variables, current_vars)
      losses.append(loss_metric[0])
    grads, eval_metrics = self.gradient_combiner.Combine(
        variables, losses_and_grads)
    var_grads = tf.nest.map_structure(
        lambda v, g: py_utils.VarGrad(var=v, grad=g), variables, grads)
  else:
    # Single loss: use its gradients directly.
    loss_metric, var_grads = LossAndGradients(loss_name)
    losses.append(loss_metric[0])

  return losses, py_utils.SkipNoneGradients(var_grads), eval_metrics
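
# -----------------------------------------------------------------------------
# A minimal sketch (hypothetical; not part of the library) of the Combine()
# contract that `self.gradient_combiner` must satisfy in the multi-loss branch
# above. Combine() receives the shared variable structure plus a dict mapping
# each loss name to a NestedMap whose `grads` field mirrors that structure,
# and returns (combined_grads, eval_metrics). The name `SumGradientCombiner`
# is an assumption: it simply sums gradients across losses, whereas real
# combiners might reweight or project them instead.
class SumGradientCombiner:
  """Hypothetical combiner: combined gradient = sum of per-loss gradients."""

  def Combine(self, variables, losses_and_grads):
    flat_vars = tf.nest.flatten(variables)
    combined = []
    for i in range(len(flat_vars)):
      # Collect the i-th gradient from every loss; a None gradient means the
      # loss does not reach this variable and contributes nothing.
      per_loss = [
          tf.nest.flatten(lg['grads'])[i] for lg in losses_and_grads.values()
      ]
      present = [g for g in per_loss if g is not None]
      combined.append(tf.add_n(present) if present else None)
    grads = tf.nest.pack_sequence_as(variables, combined)
    return grads, {}  # This sketch reports no extra eval metrics.

# Example wiring (hypothetical): a task training against two losses could set
#   self.gradient_combiner = SumGradientCombiner()
# and p.loss_name = ['loss_a', 'loss_b'] so the branch above combines them.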