Example #1
  def CpuEmbLookup(self, ids_map: Dict[str, tf.Tensor],
                   partition_strategy: str) -> Dict[str, tf.Tensor]:
    """CPU evaluation embedding lookup for dense tensors.

    Args:
      ids_map: A dict of `input_key` string -> [batch, sequence] int32 Tensor.
        For sequence embeddings, -1 is used as a padding id. Non-sequence
        embeddings do not support padded ids.
      partition_strategy: See TPUEmbeddingLayer partition_strategy param.

    Returns:
      An activations dict of string -> float32 Tensor.
      For non-sequence embeddings: [batch, 1, embedding_dim]
      For sequence embeddings: [batch, max_sequence_length, embedding_dim]
    """
    rets = py_utils.NestedMap()
    if self.max_sequence_length > 0:
      # "Sequence embedding", no combiner case
      for k, ids in ids_map.items():
        rets[k] = self._SequenceEmbLookup(ids, partition_strategy)
    else:
      # Non-"Sequence embedding", combiner case
      for k, ids in ids_map.items():
        # Dense to sparse.
        dense_shape = tf.shape(ids, out_type=tf.int64)
        sample_indices = tf.cast(tf.where(tf.not_equal(ids, -1)), tf.int64)
        embedding_indices = tf.cast(tf.gather_nd(ids, sample_indices), tf.int64)
        # [?, embedding_dim]
        sparse_ids = tf.SparseTensor(
            indices=sample_indices,
            values=embedding_indices,
            dense_shape=dense_shape)
        rets[k] = self._CombinerEmbLookup(sparse_ids, partition_strategy)
    return rets
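A minimal usage sketch for Example #1; the layer instance `emb_layer`, the input key 'ids', and the id values below are assumptions for illustration, not part of the original code:

# Hypothetical call; 'ids' must be one of the layer's configured input keys.
ids_map = {'ids': tf.constant([[3, 7, -1], [5, -1, -1]], dtype=tf.int32)}
acts = emb_layer.CpuEmbLookup(ids_map, partition_strategy='div')
# Sequence case: acts['ids'] is [2, max_sequence_length, embedding_dim].
# Combiner (non-sequence) case: acts['ids'] is [2, 1, embedding_dim].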
Example #2
    def CpuEmbLookup(self, ids_map, partition_strategy):
        """CPU evaluation embedding lookup.

    Args:
      ids_map: A dict of `input_key` string -> [batch, sequence] int32 Tensor.
        -1 is used as a padding id.
      partition_strategy: See TPUEmbeddingLayer partition_strategy param.

    Returns:
      An activations dict of string -> float32 Tensor.
      For non-sequence embeddings: [batch, 1, embedding_dim]
      For sequence embeddings: [batch, max_sequence_length, embedding_dim]

    """
        p = self.params
        rets = py_utils.NestedMap()
        if self.max_sequence_length > 0:
            # "Sequence embedding", no combiner case
            for k, ids in ids_map.items():
                embs = tf.nn.embedding_lookup(
                    self.theta.wm,
                    tf.reshape(ids, [-1]),
                    partition_strategy=partition_strategy)
                out_shape = tf.concat([tf.shape(ids), [p.embedding_dim]], 0)
                rets[k] = tf.reshape(embs, out_shape)
        else:
            # Non-"Sequence embedding", combiner case
            for k, ids in ids_map.items():
                # Dense to sparse.
                dense_shape = tf.shape(ids, out_type=tf.int64)
                sample_indices = tf.cast(tf.where(tf.not_equal(ids, -1)),
                                         tf.int64)
                embedding_indices = tf.cast(tf.gather_nd(ids, sample_indices),
                                            tf.int64)
                sparse_ids = tf.SparseTensor(indices=sample_indices,
                                             values=embedding_indices,
                                             dense_shape=dense_shape)
                # [?, embedding_dim]
                # Note: for tf.nn.embedding_lookup_sparse, output dim0 may be
                # smaller than sparse_ids.dense_shape dim0: it equals the largest
                # row index that has a non-padding id, plus one, so trailing
                # all-padding rows are dropped.
                embs = tf.nn.embedding_lookup_sparse(
                    self.theta.wm,
                    sparse_ids,
                    None,  # sp_weights
                    combiner=p.combiner,
                    partition_strategy=partition_strategy)
                batch_size = dense_shape[0]
                # Explicitly pad results to maintain dim0=batch.
                dim0_padlen = tf.cast(batch_size, tf.int32) - tf.shape(embs)[0]
                embs = tf.pad(embs, [[0, dim0_padlen], [0, 0]])
                # [batch, 1, embedding_dim]
                embs = py_utils.HasShape(embs, [batch_size], ndims=1)
                rets[k] = tf.expand_dims(embs, 1)
        return rets
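The dense-to-sparse conversion and the explicit dim0 padding in Example #2 can be reproduced in isolation with plain TensorFlow ops. This is an illustrative sketch, not part of the original code; the toy table and ids are made up and partition_strategy is omitted:

import tensorflow as tf

wm = tf.random.normal([10, 4])                    # toy [vocab, embedding_dim] table
ids = tf.constant([[3, 7], [-1, -1]], tf.int32)   # second row is all padding
dense_shape = tf.shape(ids, out_type=tf.int64)
sample_indices = tf.where(tf.not_equal(ids, -1))  # [row, col] of non-padding ids
embedding_indices = tf.cast(tf.gather_nd(ids, sample_indices), tf.int64)
sparse_ids = tf.SparseTensor(sample_indices, embedding_indices, dense_shape)
embs = tf.nn.embedding_lookup_sparse(wm, sparse_ids, None, combiner='mean')
# embs is [1, 4], not [2, 4]: trailing rows that contain only padding ids are
# dropped, which is why CpuEmbLookup pads dim0 back up to the batch size.
embs = tf.pad(embs, [[0, 2 - tf.shape(embs)[0]], [0, 0]])  # back to [2, 4]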
Example #3
def _Lookup(ids):
  # Dense to sparse.
  dense_shape = tf.shape(ids, out_type=tf.int64)
  sample_indices = tf.cast(tf.where(tf.not_equal(ids, -1)), tf.int64)
  embedding_indices = tf.cast(tf.gather_nd(ids, sample_indices), tf.int64)
  # [?, embedding_dim]
  sparse_ids = tf.SparseTensor(
      indices=sample_indices,
      values=embedding_indices,
      dense_shape=dense_shape)
  return self._CombinerEmbLookup(sparse_ids, partition_strategy)
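Example #3 is only a fragment: `_Lookup` is a nested helper that closes over `self` and `partition_strategy` from the method containing it. One plausible enclosing shape, mirroring the combiner branch of Example #1, is sketched below; the surrounding method body is an assumption, not taken from the original source:

def CpuEmbLookup(self, ids_map, partition_strategy):
  rets = py_utils.NestedMap()

  def _Lookup(ids):
    ...  # body as in Example #3 above

  for k, ids in ids_map.items():
    rets[k] = _Lookup(ids)
  return rets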