Ejemplo n.º 1
0
 def call(self, inputs):
   """Crosses `inputs`, matching the output kind to the input kind.

   When every input is ragged, the inputs are crossed through their sparse
   form and the result is converted back to a `RaggedTensor`. Otherwise the
   result stays sparse if at least one input is a `SparseTensor` and is
   converted to a dense tensor if none is. The hashed cross is used whenever
   `self.num_bins` is not `None`.

   Args:
     inputs: Sequence of tensors to cross.

   Returns:
     The crossed tensor (ragged, sparse or dense, as described above).

   Raises:
     ValueError: If some, but not all, of the inputs are ragged.
   """
   # (b/144500510) ragged.map_flat_values(sparse_cross_hashed, inputs) will
   # cause kernel failure. Investigate and find a more efficient implementation
   ragged_flags = [ragged_tensor.is_ragged(inp) for inp in inputs]
   all_ragged = all(ragged_flags)
   if not all_ragged and any(ragged_flags):
     raise ValueError('Inputs must be either all `RaggedTensor`, or none of '
                      'them should be `RaggedTensor`, got {}'.format(inputs))

   # Ragged inputs are crossed through their sparse representation.
   to_cross = [inp.to_sparse() for inp in inputs] if all_ragged else inputs
   if self.num_bins is None:
     crossed = sparse_ops.sparse_cross(to_cross)
   else:
     crossed = sparse_ops.sparse_cross_hashed(
         to_cross, num_buckets=self.num_bins)

   if all_ragged:
     return ragged_tensor.RaggedTensor.from_sparse(crossed)
   # Keep the sparse result only if the caller passed a sparse input.
   if any([isinstance(inp, sparse_tensor.SparseTensor) for inp in inputs]):
     return crossed
   return sparse_ops.sparse_tensor_to_dense(crossed)
Ejemplo n.º 2
0
    def testRaggedCross(self,
                        inputs,
                        num_buckets=0,
                        hash_key=None,
                        expected=None,
                        expected_hashed=None,
                        matches_sparse_cross=True):
        """Checks ragged cross and hashed cross results for `inputs`.

        Args:
          inputs: Tensors passed to `ragged_array_ops.cross` and
            `ragged_array_ops.cross_hashed`.
          num_buckets: Forwarded to the hashed cross ops.
          hash_key: Forwarded to the hashed cross ops.
          expected: If not `None`, the expected result of the plain cross.
          expected_hashed: If not `None`, the expected result of the hashed
            cross.
          matches_sparse_cross: If true, also require both ragged results to
            equal the corresponding `sparse_ops` cross of the same inputs.
        """
        plain_result = ragged_array_ops.cross(inputs)
        hashed_result = ragged_array_ops.cross_hashed(
            inputs, num_buckets, hash_key)

        if expected is not None:
            self.assertAllEqual(plain_result, expected)
        if expected_hashed is not None:
            self.assertAllEqual(hashed_result, expected_hashed)

        if not matches_sparse_cross:
            return

        # ragged.cross must agree with sparse.cross on the same data.
        as_sparse = [self._ragged_to_sparse(t) for t in inputs]
        sparse_plain = sparse_ops.sparse_cross(as_sparse)
        self.assertAllEqual(
            plain_result,
            ragged_tensor.RaggedTensor.from_sparse(sparse_plain))

        # ragged.cross_hashed must agree with sparse.cross_hashed as well.
        as_sparse = [self._ragged_to_sparse(t) for t in inputs]
        sparse_hashed = sparse_ops.sparse_cross_hashed(
            as_sparse, num_buckets, hash_key)
        self.assertAllEqual(
            hashed_result,
            ragged_tensor.RaggedTensor.from_sparse(sparse_hashed))
 def test_hashed_zero_bucket_no_hash_key(self):
     """Pins the hashed cross when neither `num_buckets` nor `hash_key` is given."""
     features = [
         self._sparse_tensor([['batch1-FC1-F1']]),
         self._sparse_tensor([['batch1-FC2-F1']]),
         self._sparse_tensor([['batch1-FC3-F1']]),
     ]
     hashed_cross = sparse_ops.sparse_cross_hashed(features)
     # Compare against the exact hash so accidental hashing changes are caught.
     expected = self._sparse_tensor([[1971693436396284976]])
     with self.cached_session() as sess:
         actual = sess.run(hashed_cross)
         self._assert_sparse_tensor_equals(expected, actual)
Ejemplo n.º 4
0
 def test_hashed_zero_bucket_no_hash_key(self):
   """Pins the hashed cross when neither `num_buckets` nor `hash_key` is given."""
   features = [
       self._sparse_tensor([['batch1-FC1-F1']]),
       self._sparse_tensor([['batch1-FC2-F1']]),
       self._sparse_tensor([['batch1-FC3-F1']]),
   ]
   hashed_cross = sparse_ops.sparse_cross_hashed(features)
   # Compare against the exact hash so accidental hashing changes are caught.
   expected = self._sparse_tensor([[1971693436396284976]])
   with self.cached_session() as sess:
     actual = sess.run(hashed_cross)
     self._assert_sparse_tensor_equals(expected, actual)
 def test_hashed_no_hash_key(self):
     """Pins the hashed cross into 100 buckets when no `hash_key` is passed."""
     op = sparse_ops.sparse_cross_hashed(
         [
             self._sparse_tensor([['batch1-FC1-F1']]),
             self._sparse_tensor([['batch1-FC2-F1']]),
             self._sparse_tensor([['batch1-FC3-F1']])
         ],
         num_buckets=100)
     # Check actual hashed output to prevent unintentional hashing changes.
     expected_out = self._sparse_tensor([[83]])
     # `test_session` is deprecated; `cached_session` is its replacement.
     with self.cached_session() as sess:
         self._assert_sparse_tensor_equals(expected_out, sess.run(op))
Ejemplo n.º 6
0
 def call(self, inputs):
   """Crosses `inputs`, returning a sparse or a dense result.

   The hashed cross is used when `self.num_bins` is not `None`. The result
   is left sparse if at least one input is a `SparseTensor`; otherwise it is
   converted to a dense tensor.

   Args:
     inputs: Sequence of tensors to cross.

   Returns:
     The crossed tensor, sparse or dense as described above.
   """
   has_sparse_input = any(
       [isinstance(inp, sparse_tensor.SparseTensor) for inp in inputs])
   if self.num_bins is None:
     crossed = sparse_ops.sparse_cross(inputs)
   else:
     crossed = sparse_ops.sparse_cross_hashed(
         inputs, num_buckets=self.num_bins)
   if has_sparse_input:
     return crossed
   return sparse_ops.sparse_tensor_to_dense(crossed)
 def test_hashed_zero_bucket(self):
     """Pins the hashed cross for a custom `hash_key` and default `num_buckets`."""
     features = [
         self._sparse_tensor([['batch1-FC1-F1']]),
         self._sparse_tensor([['batch1-FC2-F1']]),
         self._sparse_tensor([['batch1-FC3-F1']]),
     ]
     hashed_cross = sparse_ops.sparse_cross_hashed(
         features, hash_key=sparse_ops._DEFAULT_HASH_KEY + 1)
     # Compare against the exact hash so accidental hashing changes are caught.
     expected = self._sparse_tensor([[4847552627144134031]])
     with self.cached_session() as sess:
         actual = sess.run(hashed_cross)
         self._assert_sparse_tensor_equals(expected, actual)
 def test_hashed__has_no_collision(self):
   """Tests that fingerprint concatenation has no collisions."""
   # 359 and 1024 + 359 share their last 10 bits, yet with 1024 buckets the
   # crosses built from them must still not collide.
   colliding_ids = constant_op.constant([[359], [359 + 1024]])
   digits = constant_op.constant([list(range(10)) for _ in range(2)])
   hashed = sparse_ops.sparse_cross_hashed(
       [digits, colliding_ids],
       num_buckets=1024,
       hash_key=sparse_ops._DEFAULT_HASH_KEY + 1)
   dense = sparse_ops.sparse_tensor_to_dense(hashed)
   with session.Session():
     rows = dense.eval()
     # Every position in the first row must differ from the second row.
     self.assertTrue(numpy.not_equal(rows[0], rows[1]).all())
Ejemplo n.º 9
0
 def test_hashed_zero_bucket(self):
   """Pins the hashed cross for a custom `hash_key` and default `num_buckets`."""
   features = [
       self._sparse_tensor([['batch1-FC1-F1']]),
       self._sparse_tensor([['batch1-FC2-F1']]),
       self._sparse_tensor([['batch1-FC3-F1']]),
   ]
   hashed_cross = sparse_ops.sparse_cross_hashed(
       features, hash_key=sparse_ops._DEFAULT_HASH_KEY + 1)
   # Compare against the exact hash so accidental hashing changes are caught.
   expected = self._sparse_tensor([[4847552627144134031]])
   with self.cached_session() as sess:
     actual = sess.run(hashed_cross)
     self._assert_sparse_tensor_equals(expected, actual)
Ejemplo n.º 10
0
 def test_hashed_no_hash_key(self):
   """Pins the hashed cross into 100 buckets when no `hash_key` is passed."""
   op = sparse_ops.sparse_cross_hashed(
       [
           self._sparse_tensor([['batch1-FC1-F1']]),
           self._sparse_tensor([['batch1-FC2-F1']]),
           self._sparse_tensor([['batch1-FC3-F1']])
       ],
       num_buckets=100)
   # Check actual hashed output to prevent unintentional hashing changes.
   expected_out = self._sparse_tensor([[83]])
   # `test_session` is deprecated; `cached_session` is its replacement.
   with self.cached_session() as sess:
     self._assert_sparse_tensor_equals(expected_out, sess.run(op))
Ejemplo n.º 11
0
 def test_hashed_output(self):
     """Pins the hashed cross into 100 buckets with a custom `hash_key`."""
     op = sparse_ops.sparse_cross_hashed(
         [
             self._sparse_tensor([['batch1-FC1-F1']]),
             self._sparse_tensor([['batch1-FC2-F1']]),
             self._sparse_tensor([['batch1-FC3-F1']])
         ],
         num_buckets=100,
         hash_key=sparse_ops._DEFAULT_HASH_KEY + 1)
     # Check actual hashed output to prevent unintentional hashing changes.
     expected_out = self._sparse_tensor([[31]])
     # The session object itself is not needed because evaluation goes through
     # `self.evaluate`; the context is kept in case the assertion helper
     # relies on a default session.
     with self.cached_session():
         self._assert_sparse_tensor_equals(expected_out, self.evaluate(op))
Ejemplo n.º 12
0
 def test_hashed_output(self):
   """Pins the hashed cross into 100 buckets with a custom `hash_key`."""
   op = sparse_ops.sparse_cross_hashed(
       [
           self._sparse_tensor([['batch1-FC1-F1']]),
           self._sparse_tensor([['batch1-FC2-F1']]),
           self._sparse_tensor([['batch1-FC3-F1']])
       ],
       num_buckets=100,
       hash_key=sparse_ops._DEFAULT_HASH_KEY + 1)
   # Check actual hashed output to prevent unintentional hashing changes.
   expected_out = self._sparse_tensor([[31]])
   # The session object itself is not needed because evaluation goes through
   # `self.evaluate`; the context is kept in case the assertion helper relies
   # on a default session.
   with self.cached_session():
     self._assert_sparse_tensor_equals(expected_out, self.evaluate(op))
Ejemplo n.º 13
0
    def partial_crossing(self, partial_inputs, ragged_out, sparse_out):
        """Crosses a partial list/tuple of inputs.

        The hashed cross is used when `self.num_bins` is not `None`.

        Args:
          partial_inputs: The inputs to cross.
          ragged_out: If truthy, return the result as a `RaggedTensor`.
          sparse_out: If truthy (and `ragged_out` is not), return the sparse
            result unchanged.

        Returns:
          The crossed output: ragged, sparse, or (when neither flag is set)
          dense.
        """
        if self.num_bins is None:
            crossed = sparse_ops.sparse_cross(partial_inputs)
        else:
            crossed = sparse_ops.sparse_cross_hashed(
                partial_inputs, num_buckets=self.num_bins)

        # `ragged_out` takes precedence over `sparse_out`.
        if ragged_out:
            return ragged_tensor.RaggedTensor.from_sparse(crossed)
        if sparse_out:
            return crossed
        return sparse_ops.sparse_tensor_to_dense(crossed)
Ejemplo n.º 14
0
 def test_hashed_3x1x2(self):
     """Tests 3x1x2 permutation with hashed output."""
     features = [
         self._sparse_tensor(
             [['batch1-FC1-F1', 'batch1-FC1-F2', 'batch1-FC1-F3']]),
         self._sparse_tensor([['batch1-FC2-F1']]),
         self._sparse_tensor([['batch1-FC3-F1', 'batch1-FC3-F2']]),
     ]
     hashed_cross = sparse_ops.sparse_cross_hashed(features, num_buckets=1000)
     with self.cached_session() as sess:
         result = sess.run(hashed_cross)
         hashed_values = result.values
         # 3 * 1 * 2 combinations, all in the single batch row.
         self.assertEqual(6, len(hashed_values))
         self.assertAllEqual([[0, col] for col in range(6)], result.indices)
         # Every hash must land inside the requested bucket range.
         self.assertTrue(all(0 <= v < 1000 for v in hashed_values))
         no_duplicates = len(hashed_values) == len(set(hashed_values))
         self.assertTrue(no_duplicates)
Ejemplo n.º 15
0
 def test_hashed_3x1x2(self):
   """Tests 3x1x2 permutation with hashed output."""
   features = [
       self._sparse_tensor(
           [['batch1-FC1-F1', 'batch1-FC1-F2', 'batch1-FC1-F3']]),
       self._sparse_tensor([['batch1-FC2-F1']]),
       self._sparse_tensor([['batch1-FC3-F1', 'batch1-FC3-F2']]),
   ]
   hashed_cross = sparse_ops.sparse_cross_hashed(features, num_buckets=1000)
   with self.cached_session() as sess:
     result = sess.run(hashed_cross)
     hashed_values = result.values
     # 3 * 1 * 2 combinations, all in the single batch row.
     self.assertEqual(6, len(hashed_values))
     self.assertAllEqual([[0, col] for col in range(6)], result.indices)
     # Every hash must land inside the requested bucket range.
     self.assertTrue(all(0 <= v < 1000 for v in hashed_values))
     no_duplicates = len(hashed_values) == len(set(hashed_values))
     self.assertTrue(no_duplicates)