Code example #1
  def CpuEmbLookup(self, ids_map: Dict[str, tf.Tensor],
                   partition_strategy: str) -> Dict[str, tf.Tensor]:
    """CPU evaluation embedding lookup for dense tensors.

    Args:
      ids_map: A dict of `input_key` string -> [batch, sequence] int32 Tensor.
        For sequence embeddings, -1 is used as a padding id. Non-sequence
        embeddings do not support padded ids.
      partition_strategy: See TPUEmbeddingLayer partition_strategy param.

    Returns:
      An activations dict of string -> float32 Tensor.
      For non-sequence embeddings: [batch, 1, embedding_dim]
      For sequence embeddings: [batch, max_sequence_length, embedding_dim]
    """
    rets = py_utils.NestedMap()
    if self.max_sequence_length > 0:
      # "Sequence embedding", no combiner case
      for k, ids in ids_map.items():
        rets[k] = self._SequenceEmbLookup(ids, partition_strategy)
    else:
      # Non-"Sequence embedding", combiner case
      for k, ids in ids_map.items():
        # Dense to sparse.
        dense_shape = tf.shape(ids, out_type=tf.int64)
        sample_indices = tf.cast(tf.where(tf.not_equal(ids, -1)), tf.int64)
        embedding_indices = tf.cast(tf.gather_nd(ids, sample_indices), tf.int64)
        # [?, embedding_dim]
        sparse_ids = tf.SparseTensor(
            indices=sample_indices,
            values=embedding_indices,
            dense_shape=dense_shape)
        rets[k] = self._CombinerEmbLookup(sparse_ids, partition_strategy)
    return rets
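
The loop above converts each dense id tensor to a SparseTensor before the combiner lookup, treating -1 as the padding id. A minimal standalone sketch of that dense-to-sparse step, with made-up values (not from lingvo):

import tensorflow as tf

# Hypothetical [batch=2, sequence=3] ids; -1 marks padding positions.
ids = tf.constant([[3, 7, -1],
                   [5, -1, -1]])
sample_indices = tf.cast(tf.where(tf.not_equal(ids, -1)), tf.int64)       # [[0, 0], [0, 1], [1, 0]]
embedding_indices = tf.cast(tf.gather_nd(ids, sample_indices), tf.int64)  # [3, 7, 5]
sparse_ids = tf.SparseTensor(
    indices=sample_indices,
    values=embedding_indices,
    dense_shape=tf.shape(ids, out_type=tf.int64))                         # [2, 3]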
Code example #2
    def CreateTpuEmbeddingEnqueueOps(self):
        """Creates the TpuEmbedding enqueue ops on the host.

        Note that this must be called after the instantiation of the
        monolithic TPUEmbeddingLayer.
        """
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
        tpu_embedding = (tpu_embedding_collection[0]
                         if tpu_embedding_collection else None)

        enqueue_ops = []

        if num_tpu_hosts > 1 and tpu_embedding is not None:
            if not p.use_per_host_infeed:
                tf.logging.fatal(
                    'TPU Embedding must be used with per_host_infeed with multiple '
                    'TPU host topologies.')
        tpu_emb_input_keys = (list(tpu_embedding.feature_to_config_dict.keys())
                              if tpu_embedding is not None else [])
        tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
        if not tpu_embedding:
            return

        for task_id in range(num_infeed_hosts):
            host_device = '/task:{}/device:CPU:0'.format(task_id)
            with tf.device(host_device):
                if isinstance(self._batch, py_utils.NestedMap):
                    # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
                    # Note that when MultiTaskData is used, bucket_keys will be at the
                    # second level of the dictionary.
                    self._batch = self._batch.FilterKeyVal(
                        lambda k, _: not k.endswith('bucket_keys'))
                tf.logging.info('host_device: %s, batch: %r', host_device,
                                self._batch)

                enqueue_dict_per_core = [
                    {} for _ in range(tpu_embedding.num_cores_per_host)
                ]
                num_cores_per_host = tpu_embedding.num_cores_per_host
                for key in tpu_emb_input_keys:
                    feat = self._batch[key]
                    tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
                    for core, split in enumerate(tpu_emb_feat_splitted):
                        # Dense to sparse. Note the assumption of a padding id.
                        sample_indices = tf.where(tf.not_equal(split, -1))
                        embedding_indices = tf.gather_nd(split, sample_indices)
                        enqueue_data = tpu_embedding_lib.EnqueueData(
                            embedding_indices, sample_indices)
                        enqueue_dict_per_core[core][key] = enqueue_data
                enqueue_ops += tpu_embedding.generate_enqueue_ops(
                    enqueue_dict_per_core)
        self._tpu_infeed_op.append(tf.group(*enqueue_ops))
Code example #3
    def CpuEmbLookup(self, ids_map, partition_strategy):
        """CPU evaluation embedding lookup.

        Args:
          ids_map: A dict of `input_key` string -> [batch, sequence] int32 Tensor.
            -1 is used as a padding id.
          partition_strategy: See TPUEmbeddingLayer partition_strategy param.

        Returns:
          An activations dict of string -> float32 Tensor.
          For non-sequence embeddings: [batch, 1, embedding_dim]
          For sequence embeddings: [batch, max_sequence_length, embedding_dim]
        """
        p = self.params
        rets = py_utils.NestedMap()
        if self.max_sequence_length > 0:
            # "Sequence embedding", no combiner case
            for k, ids in ids_map.items():
                embs = tf.nn.embedding_lookup(
                    self.theta.wm,
                    tf.reshape(ids, [-1]),
                    partition_strategy=partition_strategy)
                out_shape = tf.concat([tf.shape(ids), [p.embedding_dim]], 0)
                rets[k] = tf.reshape(embs, out_shape)
        else:
            # Non-"Sequence embedding", combiner case
            for k, ids in ids_map.items():
                # Dense to sparse.
                dense_shape = tf.shape(ids, out_type=tf.int64)
                sample_indices = tf.cast(tf.where(tf.not_equal(ids, -1)),
                                         tf.int64)
                embedding_indices = tf.cast(tf.gather_nd(ids, sample_indices),
                                            tf.int64)
                sparse_ids = tf.SparseTensor(indices=sample_indices,
                                             values=embedding_indices,
                                             dense_shape=dense_shape)
                # [?, embedding_dim]
                # For tf.nn.embedding_lookup_sparse, output.dim0 might be different from
                # sparse_ids.dense_shape.dim0.
                # In fact, the '?' is the smallest span starting from the index=0 that
                # covers all the results.
                embs = tf.nn.embedding_lookup_sparse(
                    self.theta.wm,
                    sparse_ids,
                    None,  # sp_weights
                    combiner=p.combiner,
                    partition_strategy=partition_strategy)
                batch_size = dense_shape[0]
                # Explicitly pad results to maintain dim0=batch.
                dim0_padlen = tf.cast(batch_size, tf.int32) - tf.shape(embs)[0]
                embs = tf.pad(embs, [[0, dim0_padlen], [0, 0]])
                # [batch, 1, embedding_dim]
                embs = py_utils.HasShape(embs, [batch_size], ndims=1)
                rets[k] = tf.expand_dims(embs, 1)
        return rets
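
The padding step near the end exists because tf.nn.embedding_lookup_sparse only produces rows up to the last non-empty example: if the trailing examples in a batch contain nothing but padding ids, dim0 of its output is smaller than the batch size. A small sketch of that behavior and the fix, with illustrative shapes and values:

import tensorflow as tf

emb_table = tf.ones([10, 4])                     # hypothetical [vocab_size, embedding_dim] table
ids = tf.constant([[1, 2], [-1, -1]])            # second example is all padding
sample_indices = tf.where(tf.not_equal(ids, -1))
sparse_ids = tf.SparseTensor(
    indices=sample_indices,
    values=tf.cast(tf.gather_nd(ids, sample_indices), tf.int64),
    dense_shape=tf.shape(ids, out_type=tf.int64))
embs = tf.nn.embedding_lookup_sparse(emb_table, sparse_ids, None, combiner='mean')
# embs is [1, 4]: the all-padding example was dropped, so dim0 != batch.
dim0_padlen = tf.shape(ids)[0] - tf.shape(embs)[0]
embs = tf.pad(embs, [[0, dim0_padlen], [0, 0]])  # restore [2, 4]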
Code example #4
File: gshard_builder_test.py  Project: vcj-huy/lingvo
 def _Notvisible(x):
     a, b = tf.expand_dims(x, -1), tf.expand_dims(x, -2)
     return tf.cast(
         tf.math.logical_or(
             tf.not_equal(a, b),
             # also ignoring segment_id=0
             tf.math.logical_not(
                 tf.math.logical_or(tf.cast(a, tf.bool),
                                    tf.cast(b, tf.bool)))),
         tf.float32)
Code example #5
 def _Notvisible(seg_id, seg_pos):
   a, b = tf.expand_dims(seg_id, -1), tf.expand_dims(seg_id, -2)
   return tf.cast(
       tf.math.logical_or(
           tf.less(tf.expand_dims(seg_pos, -1), tf.expand_dims(seg_pos, -2)),
           tf.math.logical_or(
               tf.not_equal(a, b),
               tf.math.logical_not(
                   tf.math.logical_or(
                       tf.cast(a, tf.bool), tf.cast(b, tf.bool))))),
       tf.float32)
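
Both _Notvisible helpers build an attention mask for packed sequences: entry [i, j] is 1.0 when position i must not attend to position j, either because the two positions belong to different segments, because one of them is padding (segment_id 0), or, in the second variant, because j is a later position than i. A small hand-worked sketch of the segment-only variant with illustrative ids:

import tensorflow as tf

seg_id = tf.constant([1, 1, 2, 0])   # two packed segments plus one padding slot
a, b = tf.expand_dims(seg_id, -1), tf.expand_dims(seg_id, -2)
not_visible = tf.cast(
    tf.math.logical_or(
        tf.not_equal(a, b),                 # different segments
        tf.math.logical_not(                # either position is padding (id 0)
            tf.math.logical_or(tf.cast(a, tf.bool), tf.cast(b, tf.bool)))),
    tf.float32)
# Row 0 is [0., 0., 1., 1.]: the first token may only attend within its own segment.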
Code example #6
File: tpu_embedding_layers.py  Project: Mddct/lingvo
 def _Lookup(ids):
   # Dense to sparse.
   dense_shape = tf.shape(ids, out_type=tf.int64)
   sample_indices = tf.cast(tf.where(tf.not_equal(ids, -1)), tf.int64)
   embedding_indices = tf.cast(tf.gather_nd(ids, sample_indices), tf.int64)
   # [?, embedding_dim]
   sparse_ids = tf.SparseTensor(
       indices=sample_indices,
       values=embedding_indices,
       dense_shape=dense_shape)
   return self._CombinerEmbLookup(sparse_ids, partition_strategy)
Code example #7
def SequenceConcat(x, x_paddings, y, y_paddings, pad=0):
    """Concats sequence `x` with sequence `y`.

    This function is length aware (based on the paddings).

    Args:
      x: A sequence of tokens of shape [batch_size, x_len_max].
      x_paddings: The paddings of `x`.
      y: A sequence of tokens of shape [batch_size, y_len_max].
      y_paddings: The paddings of `y`.
      pad: The <pad> token to fill the concatenated sequence (of type integer).

    Returns:
      A tuple.
        - Concatenation of `x` and `y` of shape
          [batch_size, x_len_max + y_len_max].
        - Paddings of the concatenation of shape
          [batch_size, x_len_max + y_len_max].
    """
    # Get the length (w/ eos).
    x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)
    y_len = tf.cast(tf.round(tf.reduce_sum(1 - y_paddings, 1)), tf.int32)

    batch_size = py_utils.GetShape(x)[0]
    y_len_max = py_utils.GetShape(y)[1]

    # Pad `x` with necessary <pad>.
    x = tf.concat([x, tf.fill(py_utils.GetShape(y), pad)], 1)
    # Replace all <pad> with 0.
    x = tf.where(tf.not_equal(x, pad), x, tf.fill(py_utils.GetShape(x), 0))

    # Compute the write indices of `y` in `xy`.
    indices = tf.stack([
        tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, y_len_max]),
        (tf.tile(tf.expand_dims(tf.range(y_len_max), 0), [batch_size, 1]) +
         tf.expand_dims(x_len, 1)),
    ], 2)

    xy = x + tf.scatter_nd(indices, y, py_utils.GetShape(x))

    # We need to remap all <pad> to `pad`.
    xy = tf.where(
        tf.less(tf.expand_dims(tf.range(py_utils.GetShape(xy)[1]), 0),
                tf.expand_dims(x_len + y_len, 1)), xy,
        tf.fill(py_utils.GetShape(xy), pad))
    xy_paddings = 1 - tf.sequence_mask(x_len + y_len,
                                       py_utils.GetShape(xy)[1],
                                       x_paddings.dtype)
    return xy, xy_paddings
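
A hand-worked call to SequenceConcat, under the convention that a padding value of 1.0 marks padded positions; the token values are illustrative and SequenceConcat is assumed to be in scope from the lingvo codebase:

import tensorflow as tf

x = tf.constant([[10, 11, 0]]);  x_paddings = tf.constant([[0., 0., 1.]])  # length 2
y = tf.constant([[20, 0, 0]]);   y_paddings = tf.constant([[0., 1., 1.]])  # length 1
xy, xy_paddings = SequenceConcat(x, x_paddings, y, y_paddings, pad=0)
# xy          -> [[10, 11, 20, 0, 0, 0]]
# xy_paddings -> [[0., 0., 0., 1., 1., 1.]]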
Code example #8
  def TransformFeatures(self, features):
    # We assume that the lasers are not padded, and all points are real.
    if ('points_padding' in features.lasers and
        features.lasers.points_padding is not None):
      raise ValueError('FilterNLZPoints preprocessor does not support '
                       'padded lasers.')

    # The 3rd feature in the laser is 1.0 for points in a no-label-zone
    # and -1. for normal points.
    is_not_nlz = tf.not_equal(features.lasers.points_feature[:, 2], 1.0)
    features.lasers.points_xyz = tf.boolean_mask(features.lasers.points_xyz,
                                                 is_not_nlz)
    features.lasers.points_feature = tf.boolean_mask(
        features.lasers.points_feature, is_not_nlz)
    return features
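
A minimal standalone sketch of the same no-label-zone filtering with made-up points, showing how tf.boolean_mask drops the flagged rows from both tensors:

import tensorflow as tf

points_xyz = tf.constant([[0., 0., 1.],
                          [2., 2., 0.],
                          [3., 1., 4.]])
points_feature = tf.constant([[0.1, 0.2, 1.0],    # third channel 1.0: no-label zone
                              [0.3, 0.4, -1.0],   # normal point
                              [0.5, 0.6, -1.0]])  # normal point
is_not_nlz = tf.not_equal(points_feature[:, 2], 1.0)
points_xyz = tf.boolean_mask(points_xyz, is_not_nlz)          # keeps the last two rows
points_feature = tf.boolean_mask(points_feature, is_not_nlz)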
Code example #9
File: wpm_encoder.py  Project: linhx13/lingvo
 def _ShouldMerge(unused_tokens, candidates):
     """Merge until not possible, or we abort early according to merge_prob."""
     return tf.logical_and(
         tf.reduce_any(tf.not_equal(candidates, NO_TOKEN)),
         tf.random.uniform([]) < self._merge_prob)
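
_ShouldMerge is the condition of a merge loop: merging continues while at least one candidate merge remains (any entry differs from NO_TOKEN) and a uniform draw stays below self._merge_prob, which allows stochastic early stopping. A minimal sketch of the same predicate outside the class; NO_TOKEN and merge_prob stand in for the encoder's own values:

import tensorflow as tf

NO_TOKEN = -1      # placeholder sentinel for "no merge candidate here"
merge_prob = 1.0   # placeholder; 1.0 means never abort a possible merge early

def should_merge(candidates):
  return tf.logical_and(
      tf.reduce_any(tf.not_equal(candidates, NO_TOKEN)),  # some merge is still possible
      tf.random.uniform([]) < merge_prob)                 # stochastic early abort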
Code example #10
    def CreateTpuFeeds(self):
        """Creates the TPU infeed queue from preprocessed batch."""
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info(
            'CreateTPUFeeds num_splits_per_client={} '
            'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'.
            format(cluster.num_splits_per_client,
                   cluster.num_devices_per_split, num_tpu_hosts,
                   p.use_per_host_infeed))

        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        with py_utils.outside_all_rewrites():
            assert py_utils.use_tpu()
            assert not self._made_tpu_infeed

            shards = tpu_function.get_tpu_context(
            ).number_of_shards // num_infeed_hosts
            tf.logging.info('shards {}'.format(shards))

            input_ops_list = []
            queues = []
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            tpu_embedding = (tpu_embedding_collection[0]
                             if tpu_embedding_collection else None)

            if num_tpu_hosts > 1 and tpu_embedding is not None:
                if not p.use_per_host_infeed:
                    tf.logging.fatal(
                        'TPU Embedding must be used with per_host_infeed with multiple '
                        'TPU host topologies.')
            tpu_emb_input_keys = (list(
                tpu_embedding.feature_to_config_dict.keys())
                                  if tpu_embedding is not None else [])
            tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

            batch = None
            for task_id in range(num_infeed_hosts):
                host_device = '/task:{}/device:CPU:0'.format(task_id)
                with tf.device(host_device):
                    batch = self.GetPreprocessedInputBatch()
                    if isinstance(batch, py_utils.NestedMap):
                        # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
                        # Note that when MultiTaskData is used, bucket_keys will be at the
                        # second level of the dictionary.
                        batch = batch.FilterKeyVal(
                            lambda k, _: not k.endswith('bucket_keys'))
                    tf.logging.info('host_device: %s, batch: %r', host_device,
                                    batch)

                    if tpu_embedding is not None:
                        enqueue_dict_per_core = [
                            {} for _ in range(tpu_embedding.num_cores_per_host)
                        ]
                        num_cores_per_host = tpu_embedding.num_cores_per_host
                        for key in tpu_emb_input_keys:
                            feat = batch[key]
                            tpu_emb_feat_splitted = tf.split(
                                feat, num_cores_per_host)
                            for core, split in enumerate(
                                    tpu_emb_feat_splitted):
                                # Dense to sparse. Note the assumption of a padding id.
                                sample_indices = tf.where(
                                    tf.not_equal(split, -1))
                                embedding_indices = tf.gather_nd(
                                    split, sample_indices)
                                enqueue_data = tpu_embedding_lib.EnqueueData(
                                    embedding_indices, sample_indices)
                                enqueue_dict_per_core[core][key] = enqueue_data
                        input_ops_list += tpu_embedding.generate_enqueue_ops(
                            enqueue_dict_per_core)

                    for k, x in batch.FlattenItems():
                        assert x.shape.is_fully_defined(), (
                            'Shape must be fully defined: %s: %s' % (k, x))
                        # TODO(cwhipkey): if it's a string (or other type not supported on
                        # TPU), drop it from feeding and on the other end add in an op that
                        # fails if used.
                    shapes = batch.Transform(lambda x: x.shape).Flatten()
                    dtypes = batch.Transform(lambda x: x.dtype).Flatten()
                    tf.logging.info('host_device: %s infeed shapes: %r',
                                    host_device, shapes)
                    tf.logging.info('host_device: %s infeed dtypes: %r',
                                    host_device, dtypes)
                    if p.use_partitioned_infeed_queue:
                        device_assignment = py_utils.GetTpuDeviceAssignment()

                        host_device = device_assignment.host_device(
                            replica=0, job=tf.flags.FLAGS.tf_master)
                        host_id = int(
                            host_device.split('/task:')[1].split('/device:')
                            [0])
                        tf.logging.info('host_id: {} host_device: {}'.format(
                            host_id, host_device))
                        q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
                            number_of_tuple_elements=len(dtypes),
                            device_assignment=device_assignment,
                            host_id=host_id,
                            input_partition_dims=[[p.num_partitions, 1]
                                                  for _ in dtypes],
                            tuple_types=dtypes,
                            tuple_shapes=shapes)
                    else:
                        q = tpu_feed.InfeedQueue(tuple_types=dtypes,
                                                 tuple_shapes=shapes)
                        assert shards is not None
                        q.set_number_of_shards(shards)

                    queues.append(q)
                    tf.logging.info('q=%r', q)

                    if p.use_partitioned_infeed_queue:
                        input_ops = q.generate_enqueue_ops([batch.Flatten()])
                    elif p.use_per_host_infeed:
                        # TODO(ylc/zhifengc): Add this to a policy module and test it.
                        def TPUOrdinalFunction(shard_index_in_host):
                            device_assignment = py_utils.GetTpuDeviceAssignment(
                            )
                            if device_assignment:
                                # We put both enqueue/dequeue ops at core 0 in each replica.
                                replica = device_assignment.lookup_replicas(
                                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                                return device_assignment.tpu_ordinal(
                                    replica=replica)
                            else:
                                return shard_index_in_host

                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                            tpu_ordinal_function=TPUOrdinalFunction)
                    else:
                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            device_assignment=py_utils.GetTpuDeviceAssignment(
                            ))

                    input_ops_list += input_ops
            tf.logging.info('input_ops_list %s', input_ops_list)
            tpu_infeed_op = tf.group(*input_ops_list)
        self._made_tpu_infeed = True
        # Let trainer.py use multiple threads to drive the infeed op.
        for _ in range(p.tpu_infeed_parallelism):
            tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

        self._tpu_infeed_op = tpu_infeed_op

        with tf.device(tf.tpu.core(0)):
            tensors = queues[0].generate_dequeue_op()
        return batch.Pack(tensors)
Code example #11
    def CreateTpuFeeds(self):
        """Creates the TPU infeed queue from preprocessed batch."""
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info('num_cores_per_host {}'.format(num_cores_per_host))
        tf.logging.info('num_devices_per_split {}'.format(
            cluster.num_devices_per_split))

        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        with py_utils.outside_all_rewrites():
            assert py_utils.use_tpu()
            assert not self._made_tpu_infeed

            shards = tpu_function.get_tpu_context(
            ).number_of_shards // num_infeed_hosts
            input_ops_list = []
            queues = []
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            tpu_embedding = (tpu_embedding_collection[0]
                             if tpu_embedding_collection else None)

            tpu_emb_input_keys = (list(
                tpu_embedding.feature_to_config_dict.keys())
                                  if tpu_embedding is not None else [])
            tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

            batch = None
            for task_id in range(num_infeed_hosts):
                host_device = '/task:{}/device:CPU:0'.format(task_id)
                with tf.device(host_device):
                    batch = self.GetPreprocessedInputBatch()
                    if 'bucket_keys' in batch:
                        # Hack: bucket_keys are not needed on TPU.
                        del batch['bucket_keys']
                    tf.logging.info('host_device: %s, batch: %r', host_device,
                                    batch)

                    if tpu_embedding is not None:
                        enqueue_dict_per_core = [
                            {} for _ in range(tpu_embedding.num_cores_per_host)
                        ]
                        num_cores_per_host = tpu_embedding.num_cores_per_host
                        for key in tpu_emb_input_keys:
                            feat = batch[key]
                            tpu_emb_feat_splitted = tf.split(
                                feat, num_cores_per_host)
                            for core, split in enumerate(
                                    tpu_emb_feat_splitted):
                                # Dense to sparse. Note the assumption of a padding id.
                                sample_indices = tf.where(
                                    tf.not_equal(split, -1))
                                embedding_indices = tf.gather_nd(
                                    split, sample_indices)
                                enqueue_data = tpu_embedding_lib.EnqueueData(
                                    embedding_indices, sample_indices)
                                enqueue_dict_per_core[core][key] = enqueue_data
                        input_ops_list += tpu_embedding.generate_enqueue_ops(
                            enqueue_dict_per_core)

                    for k, x in batch.FlattenItems():
                        assert x.shape.is_fully_defined(), (
                            'Shape must be fully defined: %s: %s' % (k, x))
                        # TODO(cwhipkey): if it's a string (or other type not supported on
                        # TPU), drop it from feeding and on the other end add in an op that
                        # fails if used.
                    shapes = batch.Transform(lambda x: x.shape).Flatten()
                    dtypes = batch.Transform(lambda x: x.dtype).Flatten()
                    tf.logging.info('host_device: %s infeed shapes: %r',
                                    host_device, shapes)
                    tf.logging.info('host_device: %s infeed dtypes: %r',
                                    host_device, dtypes)
                    q = tpu_feed.InfeedQueue(tuple_types=dtypes,
                                             tuple_shapes=shapes)
                    queues.append(q)
                    assert shards is not None
                    q.set_number_of_shards(shards)

                    if p.use_per_host_infeed:

                        # TODO(ylc/zhifengc): Add this to a policy module and test it.
                        def TPUOrdinalFunction(shard_index_in_host):
                            device_assignment = py_utils.GetTpuDeviceAssignment(
                            )
                            if device_assignment:
                                # We put both enqueue/dequeue ops at core 0 in each replica.
                                replica = device_assignment.lookup_replicas(
                                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                                return device_assignment.tpu_ordinal(
                                    replica=replica)
                            else:
                                return shard_index_in_host

                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                            tpu_ordinal_function=TPUOrdinalFunction)
                    else:
                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            device_assignment=py_utils.GetTpuDeviceAssignment(
                            ))

                    input_ops_list += input_ops
            tf.logging.info('input_ops_list %s', input_ops_list)
            tpu_infeed_op = tf.group(*input_ops_list)
        self._made_tpu_infeed = True
        # Let trainer.py use multiple threads to drive the infeed op.
        for _ in range(p.tpu_infeed_parallelism):
            tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

        # For executor-driven multiple programs, we need more fine-grained
        # access rather than using a single global graph collection.
        self.tpu_infeed_op = tpu_infeed_op

        with tf.device(tf.tpu.core(0)):
            tensors = queues[0].generate_dequeue_op()
        return batch.Pack(tensors)
Code example #12
File: moe_layers.py  Project: apoorvakumar2306/lingvo
  def ComputeLoss(self, theta, activation, labels, segment_ids):
    p = self.params
    activation = self._MaybeSplit(
        self._MaybeSplit(activation) * (p.embedding_dim**-0.5))
    softmax_weights = theta.embedding
    if activation.dtype != softmax_weights.dtype:
      softmax_weights = tf.cast(softmax_weights, activation.dtype)
    logits = self._MaybeSplit(
        tf.einsum('BLM,VM->BLV', activation, softmax_weights))
    if p.logits_abs_max is not None:
      logits = self._MaybeSplit(
          py_utils.clip_by_value(logits, -p.logits_abs_max, p.logits_abs_max))

    off_value = p.label_smoothing / p.vocab_size
    on_value = 1.0 - p.label_smoothing + off_value
    soft_targets = self._MaybeSplit(
        tf.one_hot(
            labels, p.vocab_size, on_value=on_value, off_value=off_value))
    xent = self._MaybeSplit(
        tf.nn.softmax_cross_entropy_with_logits(
            labels=tf.one_hot(labels, p.vocab_size), logits=logits))
    loss = self._MaybeSplit(
        tf.nn.softmax_cross_entropy_with_logits(
            labels=soft_targets, logits=logits))
    soft_targets_xent = loss

    if p.z_loss_coef > 0.0:
      log_z = tf.math.reduce_logsumexp(logits, -1)
      z_loss_inc = p.z_loss_coef * tf.math.square(log_z)
      loss += z_loss_inc

    non_padding = self._MaybeSplit(
        tf.cast(tf.not_equal(segment_ids, 0), py_utils.FPropDtype(p)))

    per_token_loss = self._MaybeSplit(loss * non_padding)
    if p.z_loss_coef > 0.0:
      per_token_z_loss_inc = self._MaybeSplit(z_loss_inc * non_padding)

    if p.use_tgt_labels_size_as_loss_denominator:
      # E.g. loss is going to be tiny if inputs are not packed and only a
      # fraction of tgt_labels are non-padding.
      loss_denom = tf.reduce_sum(tf.ones_like(non_padding))
      per_example_loss_denom = tf.reduce_sum(tf.ones_like(non_padding), 1)
    else:
      loss_denom = tf.reduce_sum(non_padding)
      per_example_loss_denom = tf.reduce_sum(non_padding, 1)
    avg_loss = tf.reduce_sum(per_token_loss) / loss_denom
    avg_z_loss_inc = (tf.reduce_sum(per_token_z_loss_inc) /
                      loss_denom) if p.z_loss_coef > 0.0 else 0.0

    soft_targets_xent = (
        tf.reduce_sum(self._MaybeSplit(soft_targets_xent * non_padding)) /
        tf.reduce_sum(non_padding))

    # TODO(lepikhin): consider returning
    #   {'loss': (unnormalized per_token_loss, tf.reduce_sum(non_padding))}
    per_example_loss = {
        'loss': tf.reduce_sum(per_token_loss, 1) / per_example_loss_denom
    }
    return {
        'mean_xent': (tf.reduce_sum(self._MaybeSplit(xent * non_padding)) /
                      tf.reduce_sum(non_padding), tf.reduce_sum(non_padding)),
        'soft_targets_xent': (soft_targets_xent, tf.reduce_sum(non_padding)),
        'weight': (tf.reduce_sum(non_padding), 1.0),
        'loss': (avg_loss, 1.0),
        'avg_z_loss_inc': (avg_z_loss_inc, 1.0),
    }, per_example_loss
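
The soft targets above follow the usual label-smoothing formula: every wrong class receives label_smoothing / vocab_size and the true class receives 1 - label_smoothing + label_smoothing / vocab_size, so each target row still sums to 1. A quick numeric check with made-up values:

import tensorflow as tf

label_smoothing, vocab_size = 0.1, 4            # hypothetical values
off_value = label_smoothing / vocab_size        # 0.025
on_value = 1.0 - label_smoothing + off_value    # 0.925
labels = tf.constant([[2]])                     # [batch=1, seq_len=1]
soft_targets = tf.one_hot(labels, vocab_size, on_value=on_value, off_value=off_value)
# soft_targets -> [[[0.025, 0.025, 0.925, 0.025]]]; the row sums to 1.0.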