Example #1
    def _ComputeLossesAndGradients(self, metrics, vmap):
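        """Computes losses from `metrics` and their gradients w.r.t. `vmap`.

        Args:
            metrics: A dict mapping metric name to a (value, weight) tuple.
            vmap: A NestedMap of variables; filtered to the trainable ones.

        Returns:
            A tuple of (losses, var_grads, eval_metrics).
        """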
        p = self.params
        vmap = self.GetTrainableVariables(vmap)

        for v in vmap.Flatten():
            tf.logging.info('%s: bprop variable: %s', p.name, v.name)

        def LossAndGradients(metric_name):
            """Returns (loss, var_grads) computed from metrics[metric_name]."""
            metric = metrics.get(metric_name, None)
            if metric is None:
                raise ValueError('Loss %s not found in metrics %s' %
                                 (metric_name, list(metrics.keys())))
            # TODO(b/154785713): pass (loss, loss_weight) to ComputeGradients().
            loss = metric[0]
            return metric, self.optimizer.ComputeGradients(
                loss,
                vmap,
                p.grad_aggregation_method,
                p.colocate_gradients_with_ops,
                p.gate_gradients,
                compute_gradients_fn=self._CustomComputeGradientsFn(),
                skip_zero_gradients=p.skip_zero_gradients,
                skip_none_gradients=False)

        loss_name = p.loss_name or p.name
        losses = []
        eval_metrics = {}
        if isinstance(loss_name, (list, tuple)):
            losses_and_grads = {}
            variables = None
            for metric_name in loss_name:
                loss_metric, var_grads = LossAndGradients(metric_name)
                losses_and_grads[metric_name] = py_utils.NestedMap(
                    loss_metric=loss_metric,
                    grads=tf.nest.map_structure(lambda vg: vg.grad, var_grads))
                current_vars = tf.nest.map_structure(lambda vg: vg.var,
                                                     var_grads)
                if variables is None:
                    variables = current_vars
                else:
                    tf.nest.assert_same_structure(variables, current_vars)
                losses.append(loss_metric[0])

            grads, eval_metrics = self.gradient_combiner.Combine(
                variables, losses_and_grads)
            var_grads = tf.nest.map_structure(
                lambda v, g: py_utils.VarGrad(var=v, grad=g), variables, grads)
        else:
            loss_metric, var_grads = LossAndGradients(loss_name)
            losses.append(loss_metric[0])

        return losses, py_utils.SkipNoneGradients(var_grads), eval_metrics
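
A note on the multi-loss branch: `self.gradient_combiner.Combine(variables, losses_and_grads)` is expected to return one combined gradient per variable plus any combiner-specific eval metrics. The sketch below is a minimal, hypothetical combiner that simply sums gradients across losses; `SumGradientCombiner` and its exact `Combine` signature are assumptions for illustration, not the library's actual implementation.

    import tensorflow as tf

    class SumGradientCombiner:
        """Hypothetical combiner: element-wise sum of per-loss gradients."""

        def Combine(self, variables, losses_and_grads):
            del variables  # Structure is already validated by the caller.
            # Each value in losses_and_grads carries a `grads` field whose
            # structure matches `variables`.
            per_loss_grads = [lg.grads for lg in losses_and_grads.values()]

            def _Sum(*grads):
                present = [g for g in grads if g is not None]
                # Keep None when no loss produced a gradient for this
                # variable, so SkipNoneGradients() can drop it later.
                return tf.add_n(present) if present else None

            return tf.nest.map_structure(_Sum, *per_loss_grads), {}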
Example #2
    def _ComputeLossesAndGradients(self, metrics, vmap):
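        """Computes losses from `metrics` and their gradients w.r.t. `vmap`.

        This variant also collects TPU embedding activations (when present)
        so that gradients are computed for them alongside the model
        variables.

        Returns:
            A tuple of (losses, var_grads, eval_metrics).
        """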
        p = self.params
        vmap = self.GetTrainableVariables(vmap)

        # Get the TPU embedding activations to compute gradients for.
        tpu_embedding_activations = py_utils.NestedMap()
        tpu_embedding_graph_collection = py_utils.GetTpuEmbeddingGraphCollection()
        if tpu_embedding_graph_collection:
            tpu_embedding_collection = tpu_embedding_graph_collection[0]
            task_call_scope = py_utils.GetTaskCallScope()
            tpu_embedding_activations = py_utils.NestedMap(
                tpu_embedding_collection.GetActivations(task_call_scope) or {})
            # It's possible that task_call_scope is None and the mode is not
            # set in tpu_embedding_collection (e.g. in unit tests), but if the
            # activations are non-empty, the mode must have been set.
            if tpu_embedding_activations and (
                    tpu_embedding_collection.ShouldStopGradient(
                        task_call_scope)):
                tpu_embedding_activations = py_utils.NestedMap()

        for v in vmap.Flatten():
            tf.logging.info('%s: bprop variable: %s', p.name, v.name)

        def LossAndGradients(metric_name):
            """Returns (loss, var_grads) computed from metrics[metric_name]."""
            metric = metrics.get(metric_name, None)
            if metric is None:
                raise ValueError('Loss %s not found in metrics %s' %
                                 (metric_name, list(metrics.keys())))
            # TODO(b/154785713): pass (loss, loss_weight) to ComputeGradients().
            loss = metric[0]
            return metric, self.optimizer.ComputeGradients(
                loss,
                vmap,
                p.grad_aggregation_method,
                p.colocate_gradients_with_ops,
                p.gate_gradients,
                compute_gradients_fn=self._CustomComputeGradientsFn(),
                skip_zero_gradients=p.skip_zero_gradients,
                skip_none_gradients=False,
                tpu_embedding_activations=tpu_embedding_activations)

        loss_name = p.loss_name or p.name
        losses = []
        eval_metrics = {}
        if isinstance(loss_name, (list, tuple)):
            assert not tpu_embedding_activations, (
                'TPU embedding does not currently support multiple losses.')
            losses_and_grads = {}
            variables = None
            for metric_name in loss_name:
                loss_metric, var_grads = LossAndGradients(metric_name)
                losses_and_grads[metric_name] = py_utils.NestedMap(
                    loss_metric=loss_metric,
                    grads=tf.nest.map_structure(lambda vg: vg.grad, var_grads))
                current_vars = tf.nest.map_structure(lambda vg: vg.var,
                                                     var_grads)
                if variables is None:
                    variables = current_vars
                else:
                    tf.nest.assert_same_structure(variables, current_vars)
                losses.append(loss_metric[0])

            grads, eval_metrics = self.gradient_combiner.Combine(
                variables, losses_and_grads)
            var_grads = tf.nest.map_structure(
                lambda v, g: py_utils.VarGrad(var=v, grad=g), variables, grads)
        else:
            loss_metric, var_grads = LossAndGradients(loss_name)
            losses.append(loss_metric[0])

        return losses, py_utils.SkipNoneGradients(var_grads), eval_metrics
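
Example #2 differs from Example #1 only in the TPU-embedding handling: activations are fetched for the current task call scope and passed to `ComputeGradients` so that gradients can also reach the embedding tables, unless the collection asks for the gradient to be stopped. Pulled out as a standalone helper for clarity (the helper name is an illustrative assumption; it uses only the `py_utils` calls already present above):

    def _GetTpuEmbeddingActivations():
        """Hypothetical helper mirroring the activation gating above."""
        collections = py_utils.GetTpuEmbeddingGraphCollection()
        if not collections:
            # No TPU embedding in this graph; nothing to differentiate.
            return py_utils.NestedMap()
        collection = collections[0]
        scope = py_utils.GetTaskCallScope()
        activations = py_utils.NestedMap(
            collection.GetActivations(scope) or {})
        # Drop activations when gradients must not flow into the tables.
        if activations and collection.ShouldStopGradient(scope):
            return py_utils.NestedMap()
        return activations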