def _TpuEmbLookup(self) -> Dict[str, tf.Tensor]: """TPU Embedding lookup.""" activations = self._tpu_embedding.get_activations() task = py_utils.GetTaskCallScope() # We expect either None (if this is the first call) or a single item in a # list. tpu_embedding_activations = tf.get_collection( py_utils.TPU_EMBEDDING_ACTIVATIONS) if not tpu_embedding_activations: # Create a dict from task -> activations dict. tpu_embedding_activations_dict = {} tpu_embedding_activations_dict[task] = activations tf.add_to_collection(py_utils.TPU_EMBEDDING_ACTIVATIONS, tpu_embedding_activations_dict) else: # This is a subsequent call, so the dictionary already exists. tpu_embedding_activations_dict = tpu_embedding_activations[0] tpu_embedding_activations_dict[task] = activations ret = py_utils.NestedMap() for k, v in activations.items(): if k in self._sequence_features: ret[k] = v else: # Non-sequence embeddings, we fill the "time" dimension with 1. ret[k] = tf.expand_dims(v, axis=[1]) return ret
def Apply(self, metrics, vmap, gradient_mask=None, gradient_adjuster=None):
  """Computes updates on 'vmap' to optimize 'loss'.

  TODO(rpang): explore merging gradient_mask and gradient_adjuster.

  Args:
    metrics: A Dict[str, (value, weight)], from which loss can be extracted
      according to p.loss_name.
    vmap: A `.NestedMap` object containing variables to optimize.
    gradient_mask: if not None, a dict mapping variable names to a 0/1 scalar.
    gradient_adjuster: if not None, a function that mutates a given var_grads.

  Returns:
    (losses, op, eval_metrics), where
      - losses is a list of scalar tensors;
      - op is a tf.Operation to update variables;
      - eval_metrics is a Dict[str, (value, weight)], where each value/weight
        is a scalar tensor.
  """
  # We apply gradients outside the name_scope to maintain backwards
  # compatibility on variables created by self.optimizer.Apply().
  losses, var_grads, eval_metrics = self._ComputeLossesAndGradients(
      metrics, vmap)
  if 'tpu_embedding_var_grads' in var_grads:
    tpu_embedding_var_grads = var_grads.tpu_embedding_var_grads
    del var_grads.tpu_embedding_var_grads

    tpu_embedding_collection = py_utils.GetTpuEmbeddingGraphCollection()[0]
    assert tpu_embedding_collection
    tpu_emb_update_op, stats = tpu_embedding_collection.ApplyGradients(
        py_utils.GetTaskCallScope(),
        tpu_embedding_var_grads.Transform(lambda var_grad: var_grad.grad))
    eval_metrics.update(stats)
  else:
    tpu_emb_update_op = tf.no_op()

  assert py_utils.GetGlobalStep() is not None
  lr = self.LearningRate()

  var_grads, stats = self.AdjustGradients(
      var_grads,
      gradient_mask=gradient_mask,
      gradient_adjuster=gradient_adjuster)
  eval_metrics.update(stats)
  self._var_grads = var_grads

  eval_metrics['learning_rate'] = (tf.convert_to_tensor(lr),
                                   tf.convert_to_tensor(1.))

  var_update_op = tf.group(
      [tpu_emb_update_op, self.optimizer.Apply(lr, var_grads)])
  return losses, var_update_op, eval_metrics
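
# Sketch of the conditional-update pattern used in Apply() above: when there
# are no TPU-embedding gradients, tf.no_op() stands in so that tf.group() can
# combine both update ops unconditionally. build_update_op() is a
# hypothetical helper; only tf.no_op/tf.group/tf.get_variable are real APIs.
import tensorflow.compat.v1 as tf


def build_update_op(dense_update_op, emb_update_op=None):
  # Fall back to a no-op so the grouped op always has the same structure.
  if emb_update_op is None:
    emb_update_op = tf.no_op()
  return tf.group([emb_update_op, dense_update_op])


with tf.Graph().as_default():
  v = tf.get_variable('v', initializer=0.0)
  train_op = build_update_op(v.assign_add(1.0))  # no embedding update here
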
def _TpuEmbLookup(self) -> Dict[str, tf.Tensor]: """TPU Embedding lookup.""" activations = self._tpu_embedding.get_activations() task = py_utils.GetTaskCallScope() self._tpu_embedding_collection.AddActivations(task, activations) ret = py_utils.NestedMap() for k, v in activations.items(): if k in self._sequence_features: ret[k] = v else: # Non-sequence embeddings, we fill the "time" dimension with 1. ret[k] = tf.expand_dims(v, axis=[1]) return ret
def _TpuEmbLookup(self, ids_map: py_utils.NestedMap) -> py_utils.NestedMap:
  """TPU Embedding lookup."""
  task_call_scope = py_utils.GetTaskCallScope()
  activations = self._tpu_embedding_collection.AddActivations(task_call_scope)

  ret = py_utils.NestedMap()
  for k, v in activations.items():
    if ids_map.Get(k) is not None:
      if k in self._sequence_features:
        ret.Set(k, v)
      else:
        # Non-sequence embeddings, we fill the "time" dimension with 1.
        with tf.name_scope(k):
          ret.Set(k, tf.expand_dims(v, axis=[1]))
  return ret
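
# The switch from ret[k] = v to ret.Set(k, v) above matters once feature keys
# can be dotted paths into a nested structure. A short sketch of that
# behavior, assuming lingvo.core.py_utils is importable as in the surrounding
# code and relying on NestedMap's dotted-key Get/Set contract:
from lingvo.core import py_utils

m = py_utils.NestedMap()
m.Set('user.history_ids', [1, 2, 3])  # creates intermediate NestedMaps
assert m.Get('user.history_ids') == [1, 2, 3]
assert m.Get('user.missing') is None  # absent keys read back as the default
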
def _ComputeLossesAndGradients(self, metrics, vmap):
  p = self.params
  vmap = self.GetTrainableVariables(vmap)

  # Get tpu embedding activations to compute the gradients for.
  tpu_embedding_activations = py_utils.NestedMap()
  tpu_embedding_graph_collection = py_utils.GetTpuEmbeddingGraphCollection()
  if tpu_embedding_graph_collection:
    tpu_embedding_collection = tpu_embedding_graph_collection[0]
    task_call_scope = py_utils.GetTaskCallScope()
    tpu_embedding_activations = py_utils.NestedMap(
        tpu_embedding_collection.GetActivations(task_call_scope) or {})
    # It's possible that task_call_scope is None and its mode is not set in
    # tpu_embedding_collection (e.g. in unit test), but if the activation is
    # not empty, the mode must have been set.
    if tpu_embedding_activations and (
        tpu_embedding_collection.ShouldStopGradient(task_call_scope)):
      tpu_embedding_activations = py_utils.NestedMap()

  for v in vmap.Flatten():
    tf.logging.info('%s: bprop variable: %s', p.name, v.name)

  def LossAndGradients(metric_name):
    """Returns (loss, var_grads) computed from metrics[metric_name]."""
    metric = metrics.get(metric_name, None)
    if metric is None:
      raise ValueError('Loss %s not found in metrics %s' %
                       (metric_name, list(metrics.keys())))
    # TODO(b/154785713): pass (loss, loss_weight) to ComputeGradients().
    loss = metric[0]
    return metric, self.optimizer.ComputeGradients(
        loss,
        vmap,
        p.grad_aggregation_method,
        p.colocate_gradients_with_ops,
        p.gate_gradients,
        compute_gradients_fn=self._CustomComputeGradientsFn(),
        skip_zero_gradients=p.skip_zero_gradients,
        skip_none_gradients=False,
        tpu_embedding_activations=tpu_embedding_activations)

  loss_name = p.loss_name or p.name
  losses = []
  eval_metrics = {}
  if isinstance(loss_name, (list, tuple)):
    assert not tpu_embedding_activations, (
        'TPU embedding does not currently support multiple losses.')
    losses_and_grads = {}
    variables = None
    for metric_name in loss_name:
      loss_metric, var_grads = LossAndGradients(metric_name)
      losses_and_grads[metric_name] = py_utils.NestedMap(
          loss_metric=loss_metric,
          grads=tf.nest.map_structure(lambda vg: vg.grad, var_grads))
      current_vars = tf.nest.map_structure(lambda vg: vg.var, var_grads)
      if variables is None:
        variables = current_vars
      else:
        tf.nest.assert_same_structure(variables, current_vars)
      losses.append(loss_metric[0])

    grads, eval_metrics = self.gradient_combiner.Combine(
        variables, losses_and_grads)
    var_grads = tf.nest.map_structure(
        lambda v, g: py_utils.VarGrad(var=v, grad=g), variables, grads)
  else:
    loss_metric, var_grads = LossAndGradients(loss_name)
    losses.append(loss_metric[0])

  return losses, py_utils.SkipNoneGradients(var_grads), eval_metrics
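
# Sketch of the contract the multi-loss branch above expects from
# self.gradient_combiner: given one variable structure and per-loss gradients
# of that same structure, Combine() returns a single combined gradient
# structure plus an eval-metrics dict. This uniform-weight summing combiner
# is a hypothetical example, not Lingvo's implementation; it assumes entries
# shaped like the losses_and_grads NestedMaps built above and no None grads.
import tensorflow.compat.v1 as tf


class SumGradientCombiner:

  def Combine(self, variables, losses_and_grads):
    del variables  # Structure is implied by the per-loss gradients here.

    def combine_one(*grads_for_var):
      # Uniform-weight sum of one variable's gradient across all losses.
      return tf.add_n(list(grads_for_var))

    per_loss_grads = [lag.grads for lag in losses_and_grads.values()]
    combined = tf.nest.map_structure(combine_one, *per_loss_grads)
    return combined, {}  # (grads matching the variables, eval_metrics)
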