def CpuEmbLookup(self, ids_map: Dict[str, tf.Tensor],
                 partition_strategy: str) -> Dict[str, tf.Tensor]:
  """CPU evaluation embedding lookup for dense tensors.

  Args:
    ids_map: A dict of `input_key` string -> [batch, sequence] int32 Tensor.
      For sequence embeddings, -1 is used as a padding id. Non-sequence
      embeddings do not support padded ids.
    partition_strategy: See TPUEmbeddingLayer partition_strategy param.

  Returns:
    An activations dict of string -> float32 Tensor.
    For non-sequence embeddings: [batch, 1, embedding_dim]
    For sequence embeddings: [batch, max_sequence_length, embedding_dim]
  """
  rets = py_utils.NestedMap()
  if self.max_sequence_length > 0:
    # "Sequence embedding", no combiner case
    for k, ids in ids_map.items():
      rets[k] = self._SequenceEmbLookup(ids, partition_strategy)
  else:
    # Non-"Sequence embedding", combiner case
    for k, ids in ids_map.items():
      # Dense to sparse.
      dense_shape = tf.shape(ids, out_type=tf.int64)
      sample_indices = tf.cast(tf.where(tf.not_equal(ids, -1)), tf.int64)
      embedding_indices = tf.cast(tf.gather_nd(ids, sample_indices), tf.int64)
      sparse_ids = tf.SparseTensor(
          indices=sample_indices,
          values=embedding_indices,
          dense_shape=dense_shape)
      # [batch, 1, embedding_dim]
      rets[k] = self._CombinerEmbLookup(sparse_ids, partition_strategy)
  return rets
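For reference, here is a minimal standalone sketch of the dense-to-sparse conversion performed in the non-sequence branch above. The toy `ids` tensor and its shapes are illustrative assumptions; only the `-1` padding convention and the TensorFlow ops come from the code.

import tensorflow as tf

# Toy [batch=2, sequence=3] ids; -1 marks padding positions (assumed values).
ids = tf.constant([[3, 1, -1],
                   [2, -1, -1]], dtype=tf.int32)

dense_shape = tf.shape(ids, out_type=tf.int64)  # [2, 3]
# Positions of the non-padding entries, here [[0, 0], [0, 1], [1, 0]].
sample_indices = tf.cast(tf.where(tf.not_equal(ids, -1)), tf.int64)
# The non-padding ids themselves, here [3, 1, 2].
embedding_indices = tf.cast(tf.gather_nd(ids, sample_indices), tf.int64)
sparse_ids = tf.SparseTensor(
    indices=sample_indices,
    values=embedding_indices,
    dense_shape=dense_shape)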
def CpuEmbLookup(self, ids_map, partition_strategy):
  """CPU evaluation embedding lookup.

  Args:
    ids_map: A dict of `input_key` string -> [batch, sequence] int32 Tensor.
      -1 is used as a padding id.
    partition_strategy: See TPUEmbeddingLayer partition_strategy param.

  Returns:
    An activations dict of string -> float32 Tensor.
    For non-sequence embeddings: [batch, 1, embedding_dim]
    For sequence embeddings: [batch, max_sequence_length, embedding_dim]
  """
  p = self.params
  rets = py_utils.NestedMap()
  if self.max_sequence_length > 0:
    # "Sequence embedding", no combiner case
    for k, ids in ids_map.items():
      embs = tf.nn.embedding_lookup(
          self.theta.wm,
          tf.reshape(ids, [-1]),
          partition_strategy=partition_strategy)
      out_shape = tf.concat([tf.shape(ids), [p.embedding_dim]], 0)
      rets[k] = tf.reshape(embs, out_shape)
  else:
    # Non-"Sequence embedding", combiner case
    for k, ids in ids_map.items():
      # Dense to sparse.
      dense_shape = tf.shape(ids, out_type=tf.int64)
      sample_indices = tf.cast(tf.where(tf.not_equal(ids, -1)), tf.int64)
      embedding_indices = tf.cast(tf.gather_nd(ids, sample_indices), tf.int64)
      sparse_ids = tf.SparseTensor(
          indices=sample_indices,
          values=embedding_indices,
          dense_shape=dense_shape)
      # [?, embedding_dim]
      # For tf.nn.embedding_lookup_sparse, output.dim0 might be different
      # from sparse_ids.dense_shape.dim0.
      # In fact, the '?' is the smallest span starting from index=0 that
      # covers all the results.
      embs = tf.nn.embedding_lookup_sparse(
          self.theta.wm,
          sparse_ids,
          None,  # sp_weights
          combiner=p.combiner,
          partition_strategy=partition_strategy)
      batch_size = dense_shape[0]
      # Explicitly pad results to maintain dim0=batch.
      dim0_padlen = tf.cast(batch_size, tf.int32) - tf.shape(embs)[0]
      embs = tf.pad(embs, [[0, dim0_padlen], [0, 0]])
      embs = py_utils.HasShape(embs, [batch_size], ndims=1)
      # [batch, 1, embedding_dim]
      rets[k] = tf.expand_dims(embs, 1)
  return rets
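The explicit tf.pad above exists because tf.nn.embedding_lookup_sparse only returns rows up to the last non-empty example. A small self-contained sketch of that behavior follows; the toy embedding table, ids, and 'mean' combiner are illustrative assumptions, not values from the layer.

import tensorflow as tf

wm = tf.random.normal([10, 4])  # toy embedding table: 10 ids, dim 4 (assumed).

# Batch of 3 examples; the last example contains only padding (-1).
ids = tf.constant([[3, 1], [2, -1], [-1, -1]], dtype=tf.int32)
dense_shape = tf.shape(ids, out_type=tf.int64)
sample_indices = tf.where(tf.not_equal(ids, -1))
sparse_ids = tf.SparseTensor(
    indices=sample_indices,
    values=tf.cast(tf.gather_nd(ids, sample_indices), tf.int64),
    dense_shape=dense_shape)

embs = tf.nn.embedding_lookup_sparse(wm, sparse_ids, None, combiner='mean')
# embs is [2, 4]: the trailing all-padding example is dropped from dim0.
dim0_padlen = tf.cast(dense_shape[0], tf.int32) - tf.shape(embs)[0]
embs = tf.pad(embs, [[0, dim0_padlen], [0, 0]])  # back to [3, 4].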
def _Lookup(ids):
  # Dense to sparse.
  dense_shape = tf.shape(ids, out_type=tf.int64)
  sample_indices = tf.cast(tf.where(tf.not_equal(ids, -1)), tf.int64)
  embedding_indices = tf.cast(tf.gather_nd(ids, sample_indices), tf.int64)
  sparse_ids = tf.SparseTensor(
      indices=sample_indices,
      values=embedding_indices,
      dense_shape=dense_shape)
  # [batch, 1, embedding_dim]
  return self._CombinerEmbLookup(sparse_ids, partition_strategy)
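A hedged sketch of how this helper would plausibly be invoked, mirroring the per-key loop in the first CpuEmbLookup variant above; the surrounding loop is an assumption and is not part of the excerpt.

# Hypothetical call site inside CpuEmbLookup's non-sequence branch,
# assuming the same per-key loop as the first variant.
rets = py_utils.NestedMap()
for k, ids in ids_map.items():
  rets[k] = _Lookup(ids)  # [batch, 1, embedding_dim]
return rets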