def test_hashed_embedding_multi_dimension(self):
    """Checks the sampled scattered lookup when `values` is rank 2.

    The 2x3 string `values` are paired with a [2, 3, 4] tensor of sampled
    candidates, and the lookup result is expected to have shape [2, 3, 4].
    Two entries with the same value and the same candidate row must produce
    identical results. A `sampled_candidates` tensor whose shape does not
    line up with `values` must raise InvalidArgumentError.
    """
    with self.cached_session():
      embedding_weights = self._random_weights()
      values = constant_op.constant([["foo", "bar", "bar"],
                                     ["bar", "bar", "foo"]])
      # values[0][0] and values[1][2] are both "foo" and both use the
      # candidate row [1, 3, 4, 6], so their lookups must be equal.
      sampled_candidates = constant_op.constant(
          [[[1, 3, 4, 6], [1, 7, 8, 9], [1, 7, 8, 9]],
           [[1, 7, 8, 9], [1, 7, 8, 9], [1, 3, 4, 6]]])

      embedding_lookup_result = (  # pylint: disable=protected-access
          embedding_ops._sampled_scattered_embedding_lookup(
              embedding_weights,
              values,
              sampled_candidates=sampled_candidates,
              hash_key=self._hash_key).eval())

      self.assertAllEqual(embedding_lookup_result.shape, [2, 3, 4])
      self.assertAllEqual(embedding_lookup_result[0][0],
                          embedding_lookup_result[1][2])

      # Shape [2, 2, 4]: its second dimension (2) does not match the second
      # dimension of values (3), so the lookup must be rejected.
      invalid_indices = constant_op.constant([[[1, 3, 4, 6], [1, 7, 8, 9]],
                                              [[1, 7, 8, 9], [1, 7, 8, 9]]])
      # assertRaisesRegex is the supported name; assertRaisesRegexp was a
      # deprecated alias removed from unittest.TestCase in Python 3.12.
      with self.assertRaisesRegex(errors_impl.InvalidArgumentError, (
          r"\[The shape of sampled_candidates: \] \[2 2 4\] "
          r"\[ does not match the shape of values: \] \[2 3\]")):
        # pylint: disable=protected-access
        embedding_ops._sampled_scattered_embedding_lookup(
            embedding_weights, values,
            sampled_candidates=invalid_indices).eval()
    def test_hashed_embedding_multi_dimension(self):
        """Checks the sampled scattered lookup when `values` is rank 2.

        The 2x3 string `values` are paired with a [2, 3, 4] tensor of
        sampled candidates, and the lookup result is expected to have shape
        [2, 3, 4]. Two entries with the same value and the same candidate
        row must produce identical results. A `sampled_candidates` tensor
        whose shape does not line up with `values` must raise
        InvalidArgumentError.
        """
        # cached_session() is the non-deprecated replacement for
        # test_session().
        with self.cached_session():
            embedding_weights = self._random_weights()
            values = constant_op.constant([["foo", "bar", "bar"],
                                           ["bar", "bar", "foo"]])
            # values[0][0] and values[1][2] are both "foo" and both use the
            # candidate row [1, 3, 4, 6], so their lookups must be equal.
            sampled_candidates = constant_op.constant(
                [[[1, 3, 4, 6], [1, 7, 8, 9], [1, 7, 8, 9]],
                 [[1, 7, 8, 9], [1, 7, 8, 9], [1, 3, 4, 6]]])

            embedding_lookup_result = (  # pylint: disable=protected-access
                embedding_ops._sampled_scattered_embedding_lookup(
                    embedding_weights,
                    values,
                    sampled_candidates=sampled_candidates,
                    hash_key=self._hash_key).eval())

            self.assertAllEqual(embedding_lookup_result.shape, [2, 3, 4])
            self.assertAllEqual(embedding_lookup_result[0][0],
                                embedding_lookup_result[1][2])

            # Shape [2, 2, 4]: its second dimension (2) does not match the
            # second dimension of values (3), so the lookup must be rejected.
            invalid_indices = constant_op.constant([[[1, 3, 4, 6],
                                                     [1, 7, 8, 9]],
                                                    [[1, 7, 8, 9],
                                                     [1, 7, 8, 9]]])
            # assertRaisesRegex is the supported name; assertRaisesRegexp
            # was a deprecated alias removed from unittest.TestCase in
            # Python 3.12.
            with self.assertRaisesRegex(
                    errors_impl.InvalidArgumentError,
                    (r"\[The shape of sampled_candidates: \] \[2 2 4\] "
                     r"\[ does not match the shape of values: \] \[2 3\]")):
                # pylint: disable=protected-access
                embedding_ops._sampled_scattered_embedding_lookup(
                    embedding_weights,
                    values,
                    sampled_candidates=invalid_indices).eval()
Example #3
0
  def test_hashed_embedding_consistency(self):
    """Identical values with candidate rows sharing a prefix agree on it.

    Both entries of `values` are "foo"; their candidate rows differ only in
    the last position (6 vs 7). The first three looked-up entries must
    therefore match, while the fourth must differ.
    """
    with self.test_session():
      weights = self._random_weights()
      values = tf.constant(["foo", "foo"])
      candidates = tf.constant([[1, 3, 4, 6], [1, 3, 4, 7]])

      # pylint: disable=protected-access
      lookup_op = embedding_ops._sampled_scattered_embedding_lookup(
          weights,
          values,
          sampled_candidates=candidates,
          hash_key=self._hash_key)
      # pylint: enable=protected-access
      result = lookup_op.eval()

      self.assertAllEqual(result.shape, [2, 4])
      first_row = result[0]
      second_row = result[1]
      # Shared candidate prefix [1, 3, 4] -> equal leading entries.
      self.assertAllEqual(first_row[:3], second_row[:3])
      # Differing last candidate (6 vs 7) -> differing final entry.
      self.assertNotEqual(first_row[3], second_row[3])