def safe_sparse_lookup():
  """Runs `safe_embedding_lookup_sparse_v2` on a fixed 3x3 sparse input.

  The ids and the weights are both SparseTensors with the same four populated
  positions. The id values include a negative entry (-1) and the weight values
  include a negative entry (-1.).

  `sv` is not defined here; it is resolved from the surrounding scope and is
  passed as the embedding weights argument.

  Returns:
    The result of `embedding_ops.safe_embedding_lookup_sparse_v2`.
  """
  positions = [[0, 0], [0, 1], [1, 0], [2, 2]]
  shape = [3, 3]
  id_values = [0, -1, 4, 1]
  weight_values = [1., 1., -1., 1.]

  sp_ids = sparse_tensor.SparseTensor(
      indices=positions, values=id_values, dense_shape=shape)
  sp_weights = sparse_tensor.SparseTensor(
      indices=positions, values=weight_values, dense_shape=shape)

  return embedding_ops.safe_embedding_lookup_sparse_v2(sv, sp_ids, sp_weights)
def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self):
  """Checks the 3-D safe sparse lookup against a hand-computed expectation.

  The expected result holds, at position [0][0], the weighted mean of
  embedding rows 0 and 1 (weights 1.0 and 2.0), at position [1][0] embedding
  row 2, and length-4 zero vectors at every other position.
  """
  with self.cached_session():
    embedding_weights = self._random_weights()
    sparse_ids, sparse_weights = self._ids_and_weights_3d()

    # self.evaluate() instead of Tensor.eval(): it uses the default session in
    # graph mode and also works when executing eagerly, whereas .eval()
    # requires a default session.
    embedding_lookup_result = self.evaluate(
        embedding_ops.safe_embedding_lookup_sparse_v2(
            embedding_weights, sparse_ids, sparse_weights))

    weighted_mean = (1.0 * embedding_weights[0][0] +
                     2.0 * embedding_weights[0][1]) / 3.0
    self.assertAllClose(
        embedding_lookup_result,
        [[weighted_mean, [0] * 4, [0] * 4],
         [embedding_weights[0][2], [0] * 4, [0] * 4]])
def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self):
  """Checks the 3-D safe sparse lookup against a hand-computed expectation.

  The expected result holds, at position [0][0], the weighted mean of
  embedding rows 0 and 1 (weights 1.0 and 2.0), at position [1][0] embedding
  row 2, and length-4 zero vectors at every other position.
  """
  with self.cached_session():
    embedding_weights = self._random_weights()
    sparse_ids, sparse_weights = self._ids_and_weights_3d()

    # self.evaluate() instead of Tensor.eval(): it uses the default session in
    # graph mode and also works when executing eagerly, whereas .eval()
    # requires a default session.
    embedding_lookup_result = self.evaluate(
        embedding_ops.safe_embedding_lookup_sparse_v2(
            embedding_weights, sparse_ids, sparse_weights))

    weighted_mean = (1.0 * embedding_weights[0][0] +
                     2.0 * embedding_weights[0][1]) / 3.0
    self.assertAllClose(
        embedding_lookup_result,
        [[weighted_mean, [0] * 4, [0] * 4],
         [embedding_weights[0][2], [0] * 4, [0] * 4]])
def test_safe_embedding_lookup_sparse_partitioned(self):
  """Checks the unweighted safe sparse lookup over 3-shard embedding weights.

  The shards returned by `_random_weights(num_shards=3)` are chained into a
  single flat sequence of rows before building the expectation: the mean of
  rows 0 and 1, two length-4 zero vectors, row 2, and again the mean of rows
  0 and 1.
  """
  with self.cached_session():
    sharded_weights = self._random_weights(num_shards=3)
    ids, _ = self._ids_and_weights_2d()

    actual = embedding_ops.safe_embedding_lookup_sparse_v2(
        sharded_weights, ids, None)

    # Flatten the per-shard row collections into one list of rows.
    rows = list(itertools.chain.from_iterable(sharded_weights))
    mean_of_first_two = (rows[0] + rows[1]) / 2.0
    zero_row = [0] * 4
    expected = [
        mean_of_first_two,
        zero_row,
        zero_row,
        rows[2],
        mean_of_first_two,
    ]
    self.assertAllClose(actual, expected)
def test_safe_embedding_lookup_sparse_no_weights(self):
  """Checks the safe sparse lookup when `None` is passed as sparse weights.

  Expected rows, in order: the mean of embedding rows 0 and 1, two length-4
  zero vectors, embedding row 2, and again the mean of rows 0 and 1.
  """
  with self.cached_session():
    weights = self._random_weights()
    ids, _ = self._ids_and_weights_2d()

    actual = embedding_ops.safe_embedding_lookup_sparse_v2(weights, ids, None)

    table = weights[0]
    mean_of_first_two = (table[0] + table[1]) / 2.0
    zero_row = [0] * 4
    expected = [
        mean_of_first_two,
        zero_row,
        zero_row,
        table[2],
        mean_of_first_two,
    ]
    self.assertAllClose(actual, expected)
def test_safe_embedding_lookup_sparse_return_special_vector(self):
  """Checks the weighted safe sparse lookup when `default_id=3` is supplied.

  Expected rows, in order: the weighted mean of embedding rows 0 and 1
  (weights 1.0 and 2.0), embedding row 3 twice, embedding row 2, and
  embedding row 3 again.
  """
  with self.cached_session():
    weights = self._random_weights()
    ids, id_weights = self._ids_and_weights_2d()

    actual = embedding_ops.safe_embedding_lookup_sparse_v2(
        weights, ids, id_weights, default_id=3)

    table = weights[0]
    weighted_mean = (1.0 * table[0] + 2.0 * table[1]) / 3.0
    expected = [weighted_mean, table[3], table[3], table[2], table[3]]
    self.assertAllClose(actual, expected)
def test_safe_embedding_lookup_sparse_partitioned(self):
  """Checks the unweighted safe sparse lookup over 3-shard embedding weights.

  The shards returned by `_random_weights(num_shards=3)` are chained into a
  single flat sequence of rows before building the expectation: the mean of
  rows 0 and 1, two length-4 zero vectors, row 2, and again the mean of rows
  0 and 1.
  """
  with self.cached_session():
    embedding_weights = self._random_weights(num_shards=3)
    sparse_ids, _ = self._ids_and_weights_2d()

    # self.evaluate() instead of Tensor.eval(): it uses the default session in
    # graph mode and also works when executing eagerly, whereas .eval()
    # requires a default session.
    embedding_lookup_result = self.evaluate(
        embedding_ops.safe_embedding_lookup_sparse_v2(
            embedding_weights, sparse_ids, None))

    # Flatten the per-shard row collections into one list of rows.
    embedding_weights = list(itertools.chain(*embedding_weights))
    self.assertAllClose(
        embedding_lookup_result,
        [(embedding_weights[0] + embedding_weights[1]) / 2.0,
         [0] * 4,
         [0] * 4,
         embedding_weights[2],
         (embedding_weights[0] + embedding_weights[1]) / 2.0])
def test_safe_embedding_lookup_sparse_no_weights(self):
  """Checks the safe sparse lookup when `None` is passed as sparse weights.

  Expected rows, in order: the mean of embedding rows 0 and 1, two length-4
  zero vectors, embedding row 2, and again the mean of rows 0 and 1.
  """
  with self.cached_session():
    embedding_weights = self._random_weights()
    sparse_ids, _ = self._ids_and_weights_2d()

    # self.evaluate() instead of Tensor.eval(): it uses the default session in
    # graph mode and also works when executing eagerly, whereas .eval()
    # requires a default session.
    embedding_lookup_result = self.evaluate(
        embedding_ops.safe_embedding_lookup_sparse_v2(
            embedding_weights, sparse_ids, None))

    self.assertAllClose(
        embedding_lookup_result,
        [(embedding_weights[0][0] + embedding_weights[0][1]) / 2.0,
         [0] * 4,
         [0] * 4,
         embedding_weights[0][2],
         (embedding_weights[0][0] + embedding_weights[0][1]) / 2.0])
def test_safe_embedding_lookup_sparse_return_special_vector(self):
  """Checks the weighted safe sparse lookup when `default_id=3` is supplied.

  Expected rows, in order: the weighted mean of embedding rows 0 and 1
  (weights 1.0 and 2.0), embedding row 3 twice, embedding row 2, and
  embedding row 3 again.
  """
  with self.cached_session():
    embedding_weights = self._random_weights()
    sparse_ids, sparse_weights = self._ids_and_weights_2d()

    # self.evaluate() instead of Tensor.eval(): it uses the default session in
    # graph mode and also works when executing eagerly, whereas .eval()
    # requires a default session.
    embedding_lookup_result = self.evaluate(
        embedding_ops.safe_embedding_lookup_sparse_v2(
            embedding_weights, sparse_ids, sparse_weights, default_id=3))

    self.assertAllClose(
        embedding_lookup_result,
        [(1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) / 3.0,
         embedding_weights[0][3],
         embedding_weights[0][3],
         embedding_weights[0][2],
         embedding_weights[0][3]])
def _embedding_lookup_for_sparse_tensor(
    inp: sparse_tensor.SparseTensor,
    weight: Optional[sparse_tensor.SparseTensor],
    table: tf_variables.Variable,
    feature: tpu_embedding_v2_utils.FeatureConfig) -> ops.Tensor:
  """Looks up embeddings for one sparse input as directed by its feature config.

  Two strategies are used:

  * When `feature.output_shape` is falsy and `feature.max_sequence_length` is
    positive, the input is sliced to `[batch_size, max_sequence_length]`, each
    remaining id is gathered from `table`, and the rows are scattered into a
    tensor of shape `[batch_size, max_sequence_length, feature.table.dim]`.
    `weight` is not used on this path.
  * Otherwise a combined sparse lookup is performed with
    `feature.table.combiner`: `embedding_lookup_sparse_v2` when
    `feature.validate_weights_and_indices` is falsy and the input rank is
    known and at most 2, `safe_embedding_lookup_sparse_v2` in all other cases.

  Args:
    inp: a single SparseTensor input.
    weight: None or a SparseTensor with the same shape as the input.
    table: a table variable.
    feature: a feature config.

  Returns:
    Embedding lookup result.
  """
  is_sequence_feature = (
      not feature.output_shape and feature.max_sequence_length > 0)

  if is_sequence_feature:
    num_rows = math_ops.cast(array_ops.shape(inp)[0], dtype=dtypes.int64)
    slice_size = array_ops.stack(
        [num_rows, feature.max_sequence_length], axis=0)
    # TPU Embedding truncates sequences to max_sequence_length, and if we
    # don't truncate, scatter_nd will error out if the index was out of
    # bounds.
    clipped = sparse_ops.sparse_slice(inp, start=[0, 0], size=slice_size)

    result_shape = array_ops.stack(
        [num_rows, feature.max_sequence_length, feature.table.dim], axis=0)
    gathered_rows = array_ops.gather(table.read_value(), clipped.values)
    return array_ops.scatter_nd(clipped.indices, gathered_rows, result_shape)

  # Non-sequence path: combine the looked-up rows per example.
  inp_rank = inp.dense_shape.get_shape()[0]
  can_skip_validation = (
      not feature.validate_weights_and_indices and
      inp_rank is not None and
      inp_rank <= 2)

  if can_skip_validation:
    return embedding_ops.embedding_lookup_sparse_v2(
        table, inp, sp_weights=weight, combiner=feature.table.combiner)

  return embedding_ops.safe_embedding_lookup_sparse_v2(
      table, inp, sparse_weights=weight, combiner=feature.table.combiner)