def safe_sparse_lookup():
   """Run safe_embedding_lookup_sparse_v2 on a fixed 3x3 sparse input.

   NOTE(review): `sv` (the embedding weights passed to the lookup) is a
   free variable defined outside this snippet -- confirm it is in scope
   wherever this function is called.

   The id values include -1 and the weights include -1.; presumably these
   exercise the "safe" (invalid-entry pruning) path of
   safe_embedding_lookup_sparse_v2 -- TODO confirm against the TF docs.
   """
   # Sparse ids: 2 entries in row 0, 1 in row 1, 1 in row 2.
   sp_ids = sparse_tensor.SparseTensor(
       indices=[[0, 0], [0, 1], [1, 0], [2, 2]],
       values=[0, -1, 4, 1],
       dense_shape=[3, 3])
   # Per-id weights aligned with sp_ids (same indices and dense_shape).
   sp_weights = sparse_tensor.SparseTensor(
       indices=[[0, 0], [0, 1], [1, 0], [2, 2]],
       values=[1., 1., -1., 1.],
       dense_shape=[3, 3])
   return embedding_ops.safe_embedding_lookup_sparse_v2(
       sv, sp_ids, sp_weights)
  def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self):
    """3-D lookup: positions with no valid ids come back as zero vectors."""
    with self.cached_session():
      weights = self._random_weights()
      ids, wts = self._ids_and_weights_3d()

      result = embedding_ops.safe_embedding_lookup_sparse_v2(
          weights, ids, wts).eval()

      zero = [0] * 4
      # Weighted mean of the first two embeddings (weights 1.0 and 2.0).
      combined = (1.0 * weights[0][0] + 2.0 * weights[0][1]) / 3.0
      expected = [[combined, zero, zero], [weights[0][2], zero, zero]]
      self.assertAllClose(result, expected)
  def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self):
    """3-D lookup: positions with no valid ids come back as zero vectors.

    NOTE(review): this is a byte-for-byte duplicate of an earlier method
    with the same name in this file (scrape artifact); at class-creation
    time this later definition shadows the earlier one.
    """
    with self.cached_session():
      embedding_weights = self._random_weights()
      sparse_ids, sparse_weights = self._ids_and_weights_3d()

      embedding_lookup_result = (
          embedding_ops.safe_embedding_lookup_sparse_v2(
              embedding_weights, sparse_ids, sparse_weights).eval())

      # Row 0, position 0: weighted mean of embeddings 0 and 1
      # (weights 1.0 and 2.0); all other positions are zero vectors.
      self.assertAllClose(embedding_lookup_result, [[
          (1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) / 3.0,
          [0] * 4, [0] * 4
      ], [embedding_weights[0][2], [0] * 4, [0] * 4]])
# Example #4 (scraped-snippet separator; the stray "0" was a vote count)
  def test_safe_embedding_lookup_sparse_partitioned(self):
    """Sharded weight partitions combine the same as one flat weight list."""
    with self.cached_session():
      shards = self._random_weights(num_shards=3)
      ids, _ = self._ids_and_weights_2d()

      result = embedding_ops.safe_embedding_lookup_sparse_v2(
          shards, ids, None)

      # Flatten the shards to index embeddings by their global row.
      merged = list(itertools.chain(*shards))
      pair_mean = (merged[0] + merged[1]) / 2.0
      self.assertAllClose(
          result, [pair_mean, [0] * 4, [0] * 4, merged[2], pair_mean])
# Example #5 (scraped-snippet separator; the stray "0" was a vote count)
  def test_safe_embedding_lookup_sparse_no_weights(self):
    """Passing sparse_weights=None falls back to unweighted combining."""
    with self.cached_session():
      weights = self._random_weights()
      ids, _ = self._ids_and_weights_2d()

      result = embedding_ops.safe_embedding_lookup_sparse_v2(
          weights, ids, None)

      zero = [0] * 4
      # Unweighted mean of the first two embeddings.
      avg01 = (weights[0][0] + weights[0][1]) / 2.0
      self.assertAllClose(result, [avg01, zero, zero, weights[0][2], avg01])
# Example #6 (scraped-snippet separator; the stray "0" was a vote count)
  def test_safe_embedding_lookup_sparse_return_special_vector(self):
    """Rows without valid ids map to the embedding at default_id."""
    with self.cached_session():
      weights = self._random_weights()
      ids, wts = self._ids_and_weights_2d()

      result = embedding_ops.safe_embedding_lookup_sparse_v2(
          weights, ids, wts, default_id=3)

      fallback = weights[0][3]
      # Weighted mean of embeddings 0 and 1 (weights 1.0 and 2.0).
      combined = (1.0 * weights[0][0] + 2.0 * weights[0][1]) / 3.0
      self.assertAllClose(
          result, [combined, fallback, fallback, weights[0][2], fallback])
  def test_safe_embedding_lookup_sparse_partitioned(self):
    """Sharded weights produce the same lookup as an unsharded table."""
    with self.cached_session():
      shard_list = self._random_weights(num_shards=3)
      ids, _ = self._ids_and_weights_2d()

      looked_up = embedding_ops.safe_embedding_lookup_sparse_v2(
          shard_list, ids, None).eval()

      # Chain the shards back into one flat list of embedding rows.
      flat = list(itertools.chain(*shard_list))
      mean01 = (flat[0] + flat[1]) / 2.0
      expected = [mean01, [0] * 4, [0] * 4, flat[2], mean01]
      self.assertAllClose(looked_up, expected)
  def test_safe_embedding_lookup_sparse_no_weights(self):
    """With sparse_weights=None each id contributes equally to the mean."""
    with self.cached_session():
      emb = self._random_weights()
      ids, _ = self._ids_and_weights_2d()

      looked_up = embedding_ops.safe_embedding_lookup_sparse_v2(
          emb, ids, None).eval()

      mean_first_two = (emb[0][0] + emb[0][1]) / 2.0
      self.assertAllClose(
          looked_up,
          [mean_first_two, [0] * 4, [0] * 4, emb[0][2], mean_first_two])
  def test_safe_embedding_lookup_sparse_return_special_vector(self):
    """default_id substitutes that table row for rows lacking valid ids."""
    with self.cached_session():
      emb = self._random_weights()
      ids, wts = self._ids_and_weights_2d()

      looked_up = embedding_ops.safe_embedding_lookup_sparse_v2(
          emb, ids, wts, default_id=3).eval()

      fallback = emb[0][3]
      # Weighted mean of embeddings 0 and 1 (weights 1.0 and 2.0).
      combined = (1.0 * emb[0][0] + 2.0 * emb[0][1]) / 3.0
      self.assertAllClose(
          looked_up, [combined, fallback, fallback, emb[0][2], fallback])
# Example #10 (scraped-snippet separator; the stray "0" was a vote count)
def _embedding_lookup_for_sparse_tensor(
        inp: sparse_tensor.SparseTensor,
        weight: Optional[sparse_tensor.SparseTensor],
        table: tf_variables.Variable,
        feature: tpu_embedding_v2_utils.FeatureConfig) -> ops.Tensor:
    """Embedding lookup for sparse tensor based on its feature config.

    Args:
      inp: a single SparseTensor input.
      weight: None or SparseTensor which has the same shape of the input.
      table: a table variable.
      feature: a feature config.

    Returns:
      Embedding lookup result.
    """
    # Sequence features (no explicit output shape, positive max length)
    # are scattered into a dense [batch, max_seq_len, dim] tensor.
    if not feature.output_shape and feature.max_sequence_length > 0:
        batch = math_ops.cast(array_ops.shape(inp)[0], dtype=dtypes.int64)
        seq_shape = array_ops.stack(
            [batch, feature.max_sequence_length], axis=0)
        # TPU Embedding truncates sequences to max_sequence_length; without
        # truncating here, scatter_nd would error on out-of-bounds indices.
        clipped = sparse_ops.sparse_slice(inp, start=[0, 0], size=seq_shape)
        dense_shape = array_ops.stack(
            [batch, feature.max_sequence_length, feature.table.dim], axis=0)
        return array_ops.scatter_nd(
            clipped.indices,
            array_ops.gather(table.read_value(), clipped.values),
            dense_shape)

    rank = inp.dense_shape.get_shape()[0]
    # Fast path: validation explicitly disabled and rank statically <= 2.
    if (not feature.validate_weights_and_indices and rank is not None
            and rank <= 2):
        return embedding_ops.embedding_lookup_sparse_v2(
            table, inp, sp_weights=weight, combiner=feature.table.combiner)
    # Otherwise use the safe variant, which validates ids/weights.
    return embedding_ops.safe_embedding_lookup_sparse_v2(
        table,
        inp,
        sparse_weights=weight,
        combiner=feature.table.combiner)