Example #1
    def _TpuEmbLookup(self) -> Dict[str, tf.Tensor]:
        """TPU Embedding lookup."""
        activations = self._tpu_embedding.get_activations()
        task = py_utils.GetTaskCallScope()
        # We expect either an empty list (if this is the first call) or a list
        # containing a single item.
        tpu_embedding_activations = tf.get_collection(
            py_utils.TPU_EMBEDDING_ACTIVATIONS)
        if not tpu_embedding_activations:
            # Create a dict from task -> activations dict.
            tpu_embedding_activations_dict = {}
            tpu_embedding_activations_dict[task] = activations
            tf.add_to_collection(py_utils.TPU_EMBEDDING_ACTIVATIONS,
                                 tpu_embedding_activations_dict)
        else:
            # This is a subsequent call, so the dictionary already exists.
            tpu_embedding_activations_dict = tpu_embedding_activations[0]
            tpu_embedding_activations_dict[task] = activations

        ret = py_utils.NestedMap()
        for k, v in activations.items():
            if k in self._sequence_features:
                ret[k] = v
            else:
                # For non-sequence embeddings, fill the "time" dimension with 1.
                ret[k] = tf.expand_dims(v, axis=[1])
        return ret
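As a quick illustration of the shape convention the loop above enforces, here is a minimal sketch in plain TensorFlow (not lingvo code; the shapes are made up): sequence features keep their [batch, time, dim] activations, while non-sequence activations of shape [batch, dim] gain a singleton "time" axis.

import tensorflow as tf

non_seq = tf.zeros([4, 8])     # non-sequence activation: [batch, dim]
seq = tf.zeros([4, 16, 8])     # sequence activation: [batch, time, dim]

expanded = tf.expand_dims(non_seq, axis=1)
print(expanded.shape)  # (4, 1, 8) -- singleton "time" dimension added
print(seq.shape)       # (4, 16, 8) -- returned unchanged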
Example #2
File: learner.py Project: Mddct/lingvo
    def Apply(self, metrics, vmap, gradient_mask=None, gradient_adjuster=None):
        """Computes updates on 'vmap' to optimize 'loss'.

    TODO(rpang): explore merging gradient_mask and gradient_adjuster.

    Args:
      metrics: A Dict[str, (value, weight)], from which loss can be extracted
        according to p.loss_name.
      vmap: A `.NestedMap` object containing variables to optimize.
      gradient_mask: if not None, a dict mapping variable names to a 0/1 scalar.
      gradient_adjuster: if not None, a function that mutates a given var_grads.

    Returns:
      (losses, op, eval_metrics), where
        - losses is a list of scalar tensors;
        - op is a tf.Operation to update variables;
        - eval_metrics is a Dict[str, (value, weight)], where each value/weight
          is a scalar tensor.
    """
        # We apply gradients outside the name_scope to maintain backwards
        # compatibility on variables created by self.optimizer.Apply().
        losses, var_grads, eval_metrics = self._ComputeLossesAndGradients(
            metrics, vmap)
        if 'tpu_embedding_var_grads' in var_grads:
            tpu_embedding_var_grads = var_grads.tpu_embedding_var_grads
            del var_grads.tpu_embedding_var_grads

            tpu_embedding_collection = py_utils.GetTpuEmbeddingGraphCollection()[0]
            assert tpu_embedding_collection
            tpu_emb_update_op, stats = tpu_embedding_collection.ApplyGradients(
                py_utils.GetTaskCallScope(),
                tpu_embedding_var_grads.Transform(
                    lambda var_grad: var_grad.grad))
            eval_metrics.update(stats)
        else:
            tpu_emb_update_op = tf.no_op()

        assert py_utils.GetGlobalStep() is not None
        lr = self.LearningRate()

        var_grads, stats = self.AdjustGradients(
            var_grads,
            gradient_mask=gradient_mask,
            gradient_adjuster=gradient_adjuster)
        eval_metrics.update(stats)
        self._var_grads = var_grads

        eval_metrics['learning_rate'] = (tf.convert_to_tensor(lr),
                                         tf.convert_to_tensor(1.))

        var_update_op = tf.group(
            [tpu_emb_update_op,
             self.optimizer.Apply(lr, var_grads)])
        return losses, var_update_op, eval_metrics
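The last few lines above fuse the TPU-embedding update (or a no-op when there are no embedding gradients) with the dense-variable update into a single op. A toy sketch of that grouping pattern in plain TensorFlow; the variable and the assign_add below are hypothetical stand-ins for optimizer.Apply(lr, var_grads), not lingvo code.

import tensorflow as tf

tpu_emb_update_op = tf.no_op()       # placeholder when no embedding gradients exist
v = tf.Variable(1.0)                 # hypothetical dense variable
dense_update_op = v.assign_add(0.1)  # stand-in for optimizer.Apply(lr, var_grads)
var_update_op = tf.group([tpu_emb_update_op, dense_update_op])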
Example #3
  def _TpuEmbLookup(self) -> Dict[str, tf.Tensor]:
    """TPU Embedding lookup."""
    activations = self._tpu_embedding.get_activations()
    task = py_utils.GetTaskCallScope()
    self._tpu_embedding_collection.AddActivations(task, activations)

    ret = py_utils.NestedMap()
    for k, v in activations.items():
      if k in self._sequence_features:
        ret[k] = v
      else:
        # For non-sequence embeddings, fill the "time" dimension with 1.
        ret[k] = tf.expand_dims(v, axis=[1])
    return ret
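Compared with example #1, the manual TPU_EMBEDDING_ACTIVATIONS graph-collection bookkeeping is folded into a single AddActivations call on a collection object. A toy sketch (not the lingvo implementation) of the per-task bookkeeping such a helper encapsulates:

class _ToyEmbeddingCollection:
  """Toy stand-in: stores activation maps keyed by task call scope."""

  def __init__(self):
    self._activations_by_task = {}

  def AddActivations(self, task, activations):
    self._activations_by_task[task] = activations

  def GetActivations(self, task):
    return self._activations_by_task.get(task)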
Example #4
  def _TpuEmbLookup(self, ids_map: py_utils.NestedMap) -> py_utils.NestedMap:
    """TPU Embedding lookup."""
    task_call_scope = py_utils.GetTaskCallScope()
    activations = self._tpu_embedding_collection.AddActivations(task_call_scope)

    ret = py_utils.NestedMap()
    for k, v in activations.items():
      if ids_map.Get(k) is not None:
        if k in self._sequence_features:
          ret.Set(k, v)
        else:
          # For non-sequence embeddings, fill the "time" dimension with 1.
          with tf.name_scope(k):
            ret.Set(k, tf.expand_dims(v, axis=[1]))
    return ret
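Relative to example #3, the activations are now returned by AddActivations itself, and only features actually requested in ids_map make it into the result. A rough sketch of that filtering using plain dicts instead of py_utils.NestedMap (the function name below is made up for illustration):

import tensorflow as tf

def _FilterAndExpand(activations, ids_map, sequence_features):
  """Toy version of the loop above: keep requested features only."""
  ret = {}
  for k, v in activations.items():
    if ids_map.get(k) is None:
      continue  # feature was not looked up for this task
    ret[k] = v if k in sequence_features else tf.expand_dims(v, axis=1)
  return ret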
Example #5
    def _ComputeLossesAndGradients(self, metrics, vmap):
        p = self.params
        vmap = self.GetTrainableVariables(vmap)

        # Get tpu embedding activations to compute the gradients for.
        tpu_embedding_activations = py_utils.NestedMap()
        tpu_embedding_graph_collection = py_utils.GetTpuEmbeddingGraphCollection()
        if tpu_embedding_graph_collection:
            tpu_embedding_collection = tpu_embedding_graph_collection[0]
            task_call_scope = py_utils.GetTaskCallScope()
            tpu_embedding_activations = py_utils.NestedMap(
                tpu_embedding_collection.GetActivations(task_call_scope) or {})
            # It's possible that task_call_scope is None and its mode is not set in
            # tpu_embedding_collection (e.g. in unit tests), but if the activations
            # are not empty, the mode must have been set.
            if tpu_embedding_activations and (
                    tpu_embedding_collection.ShouldStopGradient(
                        task_call_scope)):
                tpu_embedding_activations = py_utils.NestedMap()

        for v in vmap.Flatten():
            tf.logging.info('%s: bprop variable: %s', p.name, v.name)

        def LossAndGradients(metric_name):
            """Returns (loss, var_grads) computed from metrics[metric_name]."""
            metric = metrics.get(metric_name, None)
            if metric is None:
                raise ValueError('Loss %s not found in metrics %s' %
                                 (metric_name, list(metrics.keys())))
            # TODO(b/154785713): pass (loss, loss_weight) to ComputeGradients().
            loss = metric[0]
            return metric, self.optimizer.ComputeGradients(
                loss,
                vmap,
                p.grad_aggregation_method,
                p.colocate_gradients_with_ops,
                p.gate_gradients,
                compute_gradients_fn=self._CustomComputeGradientsFn(),
                skip_zero_gradients=p.skip_zero_gradients,
                skip_none_gradients=False,
                tpu_embedding_activations=tpu_embedding_activations)

        loss_name = p.loss_name or p.name
        losses = []
        eval_metrics = {}
        if isinstance(loss_name, (list, tuple)):
            assert not tpu_embedding_activations, (
                'TPU embedding does not currently support multiple losses.')
            losses_and_grads = {}
            variables = None
            for metric_name in loss_name:
                loss_metric, var_grads = LossAndGradients(metric_name)
                losses_and_grads[metric_name] = py_utils.NestedMap(
                    loss_metric=loss_metric,
                    grads=tf.nest.map_structure(lambda vg: vg.grad, var_grads))
                current_vars = tf.nest.map_structure(lambda vg: vg.var,
                                                     var_grads)
                if variables is None:
                    variables = current_vars
                else:
                    tf.nest.assert_same_structure(variables, current_vars)
                losses.append(loss_metric[0])

            grads, eval_metrics = self.gradient_combiner.Combine(
                variables, losses_and_grads)
            var_grads = tf.nest.map_structure(
                lambda v, g: py_utils.VarGrad(var=v, grad=g), variables, grads)
        else:
            loss_metric, var_grads = LossAndGradients(loss_name)
            losses.append(loss_metric[0])

        return losses, py_utils.SkipNoneGradients(var_grads), eval_metrics
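In the multi-loss branch above, gradient_combiner.Combine merges per-loss gradients computed over the same variable structure into one gradient per variable. As a rough sketch of what such a combiner could do (simple summation over plain dicts; the actual combiner is configurable and also returns eval metrics, which is omitted here):

import tensorflow as tf

def _SumCombine(losses_and_grads):
  """Toy combiner: sum the per-loss gradients, element-wise per variable."""
  per_loss_grads = [lg['grads'] for lg in losses_and_grads.values()]
  return tf.nest.map_structure(lambda *gs: tf.add_n(list(gs)), *per_loss_grads)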