def test_hashed_embedding_multi_dimension(self):
  """Checks `_sampled_scattered_embedding_lookup` on rank-2 `values`.

  A [2, 3] string tensor paired with [2, 3, 4] `sampled_candidates` must
  produce a [2, 3, 4] result. A `sampled_candidates` tensor whose shape does
  not line up with `values` must be rejected with an InvalidArgumentError.
  """
  with self.cached_session():
    embedding_weights = self._random_weights()
    values = constant_op.constant([["foo", "bar", "bar"],
                                   ["bar", "bar", "foo"]])
    sampled_candidates = constant_op.constant(
        [[[1, 3, 4, 6], [1, 7, 8, 9], [1, 7, 8, 9]],
         [[1, 7, 8, 9], [1, 7, 8, 9], [1, 3, 4, 6]]])
    embedding_lookup_result = (
        # pylint: disable=protected-access
        embedding_ops._sampled_scattered_embedding_lookup(
            embedding_weights,
            values,
            sampled_candidates=sampled_candidates,
            hash_key=self._hash_key).eval())

    self.assertAllEqual(embedding_lookup_result.shape, [2, 3, 4])
    # values[0][0] and values[1][2] are both "foo" and use the same candidate
    # row [1, 3, 4, 6], so their looked-up embeddings must be identical.
    self.assertAllEqual(embedding_lookup_result[0][0],
                        embedding_lookup_result[1][2])

    # Only two candidate rows per example, while `values` has three entries
    # per example: the op must reject this shape mismatch.
    invalid_indices = constant_op.constant([[[1, 3, 4, 6], [1, 7, 8, 9]],
                                            [[1, 7, 8, 9], [1, 7, 8, 9]]])
    # assertRaisesRegexp is a deprecated alias (removed in Python 3.12);
    # assertRaisesRegex is the supported spelling.
    with self.assertRaisesRegex(
        errors_impl.InvalidArgumentError,
        (r"\[The shape of sampled_candidates: \] \[2 2 4\] "
         r"\[ does not match the shape of values: \] \[2 3\]")):
      # pylint: disable=protected-access
      embedding_ops._sampled_scattered_embedding_lookup(
          embedding_weights, values,
          sampled_candidates=invalid_indices).eval()
def test_hashed_embedding_multi_dimension(self):
  """Checks `_sampled_scattered_embedding_lookup` on rank-2 `values`.

  A [2, 3] string tensor paired with [2, 3, 4] `sampled_candidates` must
  produce a [2, 3, 4] result. A `sampled_candidates` tensor whose shape does
  not line up with `values` must be rejected with an InvalidArgumentError.
  """
  # test_session() is deprecated; cached_session() is its replacement.
  with self.cached_session():
    embedding_weights = self._random_weights()
    values = constant_op.constant([["foo", "bar", "bar"],
                                   ["bar", "bar", "foo"]])
    sampled_candidates = constant_op.constant(
        [[[1, 3, 4, 6], [1, 7, 8, 9], [1, 7, 8, 9]],
         [[1, 7, 8, 9], [1, 7, 8, 9], [1, 3, 4, 6]]])
    embedding_lookup_result = (
        # pylint: disable=protected-access
        embedding_ops._sampled_scattered_embedding_lookup(
            embedding_weights,
            values,
            sampled_candidates=sampled_candidates,
            hash_key=self._hash_key).eval())

    self.assertAllEqual(embedding_lookup_result.shape, [2, 3, 4])
    # values[0][0] and values[1][2] are both "foo" and use the same candidate
    # row [1, 3, 4, 6], so their looked-up embeddings must be identical.
    self.assertAllEqual(embedding_lookup_result[0][0],
                        embedding_lookup_result[1][2])

    # Only two candidate rows per example, while `values` has three entries
    # per example: the op must reject this shape mismatch.
    invalid_indices = constant_op.constant([[[1, 3, 4, 6], [1, 7, 8, 9]],
                                            [[1, 7, 8, 9], [1, 7, 8, 9]]])
    # assertRaisesRegexp is a deprecated alias (removed in Python 3.12);
    # assertRaisesRegex is the supported spelling.
    with self.assertRaisesRegex(
        errors_impl.InvalidArgumentError,
        (r"\[The shape of sampled_candidates: \] \[2 2 4\] "
         r"\[ does not match the shape of values: \] \[2 3\]")):
      # pylint: disable=protected-access
      embedding_ops._sampled_scattered_embedding_lookup(
          embedding_weights, values,
          sampled_candidates=invalid_indices).eval()
def test_hashed_embedding_consistency(self):
  """Checks that shared candidate ids yield shared embedding entries.

  Both values are "foo". Their candidate rows agree in the first three ids
  and differ only in the last, so the first three output columns must match
  and the last column must differ.
  """
  # test_session() is deprecated; cached_session() is its replacement.
  with self.cached_session():
    embedding_weights = self._random_weights()
    values = tf.constant(["foo", "foo"])
    # The first three sampled_candidates are equal, so the first three
    # embedding weights will be equal.
    sampled_candidates = tf.constant([[1, 3, 4, 6], [1, 3, 4, 7]])
    embedding_lookup_result = (
        # pylint: disable=protected-access
        embedding_ops._sampled_scattered_embedding_lookup(
            embedding_weights,
            values,
            sampled_candidates=sampled_candidates,
            hash_key=self._hash_key).eval())

    self.assertAllEqual(embedding_lookup_result.shape, [2, 4])
    self.assertAllEqual(embedding_lookup_result[0][:3],
                        embedding_lookup_result[1][:3])
    # NOTE(review): relies on _random_weights() holding different values at
    # the positions selected via candidates 6 and 7 -- confirm.
    self.assertNotEqual(embedding_lookup_result[0][3],
                        embedding_lookup_result[1][3])