def call(self, inputs):
  """Crosses a list of feature tensors.

  Args:
    inputs: A list of tensors to cross. Either every element is a
      `RaggedTensor`, or none of them is.

  Returns:
    A `RaggedTensor` when all inputs are ragged; otherwise a `SparseTensor`
    when any input is a `SparseTensor`, and a dense tensor when none is.
    Crossed values are hashed into `self.num_bins` buckets when `num_bins`
    is not None.

  Raises:
    ValueError: If only some of the inputs are `RaggedTensor`s.
  """

  def _cross(cross_inputs):
    # Hash into buckets when num_bins is configured; otherwise emit the
    # unhashed crosses.
    if self.num_bins is not None:
      return sparse_ops.sparse_cross_hashed(
          cross_inputs, num_buckets=self.num_bins)
    return sparse_ops.sparse_cross(cross_inputs)

  # Evaluate the ragged check once per input and reuse it for both the
  # all-ragged and the partially-ragged tests below.
  ragged_flags = [ragged_tensor.is_ragged(inp) for inp in inputs]

  # (b/144500510) ragged.map_flat_values(sparse_cross_hashed, inputs) will
  # cause kernel failure. Investigate and find a more efficient implementation
  if all(ragged_flags):
    # Every input is ragged here, so each one can be converted directly.
    sparse_inputs = [inp.to_sparse() for inp in inputs]
    return ragged_tensor.RaggedTensor.from_sparse(_cross(sparse_inputs))

  if any(ragged_flags):
    raise ValueError('Inputs must be either all `RaggedTensor`, or none of '
                     'them should be `RaggedTensor`, got {}'.format(inputs))

  sparse_output = any(
      isinstance(inp, sparse_tensor.SparseTensor) for inp in inputs)
  output = _cross(inputs)
  if not sparse_output:
    output = sparse_ops.sparse_tensor_to_dense(output)
  return output
def testRaggedCross(self,
                    inputs,
                    num_buckets=0,
                    hash_key=None,
                    expected=None,
                    expected_hashed=None,
                    matches_sparse_cross=True):
  """Checks ragged cross / cross_hashed against expectations and sparse ops.

  Args:
    inputs: Tensors passed to `ragged_array_ops.cross` and
      `ragged_array_ops.cross_hashed`.
    num_buckets: Bucket count forwarded to both hashed cross ops.
    hash_key: Hash key forwarded to both hashed cross ops.
    expected: If not None, the expected unhashed ragged cross.
    expected_hashed: If not None, the expected hashed ragged cross.
    matches_sparse_cross: If True, also verify that the ragged results equal
      the `sparse_ops` results computed on sparse versions of `inputs`.
  """
  ragged_cross = ragged_array_ops.cross(inputs)
  ragged_cross_hashed = ragged_array_ops.cross_hashed(
      inputs, num_buckets, hash_key)

  if expected is not None:
    self.assertAllEqual(ragged_cross, expected)
  if expected_hashed is not None:
    self.assertAllEqual(ragged_cross_hashed, expected_hashed)

  if matches_sparse_cross:
    # Convert the inputs once and reuse them for both comparisons below.
    sparse_inputs = [self._ragged_to_sparse(t) for t in inputs]

    # Check that ragged.cross & sparse.cross match.
    sparse_cross = sparse_ops.sparse_cross(sparse_inputs)
    self.assertAllEqual(
        ragged_cross, ragged_tensor.RaggedTensor.from_sparse(sparse_cross))

    # Check that ragged.cross_hashed & sparse.cross_hashed match.
    sparse_cross_hashed = sparse_ops.sparse_cross_hashed(
        sparse_inputs, num_buckets, hash_key)
    self.assertAllEqual(
        ragged_cross_hashed,
        ragged_tensor.RaggedTensor.from_sparse(sparse_cross_hashed))
def test_hashed_zero_bucket_no_hash_key(self):
  """Hashed cross of three single-value columns with default arguments."""
  feature_columns = [
      self._sparse_tensor([['batch1-FC1-F1']]),
      self._sparse_tensor([['batch1-FC2-F1']]),
      self._sparse_tensor([['batch1-FC3-F1']]),
  ]
  # Neither num_buckets nor hash_key is passed, so the op's defaults apply.
  op = sparse_ops.sparse_cross_hashed(feature_columns)
  # Pin the exact fingerprint so unintentional hashing changes are caught.
  expected_out = self._sparse_tensor([[1971693436396284976]])
  with self.cached_session() as sess:
    actual = sess.run(op)
    self._assert_sparse_tensor_equals(expected_out, actual)
def test_hashed_no_hash_key(self):
  """Hashed cross into 100 buckets without an explicit hash_key."""
  op = sparse_ops.sparse_cross_hashed([
      self._sparse_tensor([['batch1-FC1-F1']]),
      self._sparse_tensor([['batch1-FC2-F1']]),
      self._sparse_tensor([['batch1-FC3-F1']])
  ], num_buckets=100)
  # Check actual hashed output to prevent unintentional hashing changes.
  expected_out = self._sparse_tensor([[83]])
  # `test_session` is deprecated; `cached_session` is its documented
  # replacement and is what the other hashed-cross tests use.
  with self.cached_session() as sess:
    self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def call(self, inputs):
  """Crosses `inputs`, keeping the result sparse only for sparse input.

  Args:
    inputs: A list of tensors to cross.

  Returns:
    The `SparseTensor` produced by the cross op if any input is a
    `SparseTensor`; otherwise that result converted to a dense tensor.
    Values are hashed into `self.num_bins` buckets when `num_bins` is not
    None.
  """
  sparse_flags = [isinstance(inp, sparse_tensor.SparseTensor) for inp in inputs]
  keep_sparse = any(sparse_flags)

  if self.num_bins is None:
    crossed = sparse_ops.sparse_cross(inputs)
  else:
    crossed = sparse_ops.sparse_cross_hashed(
        inputs, num_buckets=self.num_bins)

  if keep_sparse:
    return crossed
  return sparse_ops.sparse_tensor_to_dense(crossed)
def test_hashed_zero_bucket(self):
  """Hashed cross with default num_buckets and a non-default hash key."""
  columns = [
      self._sparse_tensor([['batch1-FC1-F1']]),
      self._sparse_tensor([['batch1-FC2-F1']]),
      self._sparse_tensor([['batch1-FC3-F1']]),
  ]
  custom_key = sparse_ops._DEFAULT_HASH_KEY + 1
  op = sparse_ops.sparse_cross_hashed(columns, hash_key=custom_key)
  # Pin the exact fingerprint so unintentional hashing changes are caught.
  expected_out = self._sparse_tensor([[4847552627144134031]])
  with self.cached_session() as sess:
    actual = sess.run(op)
    self._assert_sparse_tensor_equals(expected_out, actual)
def test_hashed__has_no_collision(self):
  """Tests that fingerprint concatenation has no collisions."""
  # 359 and 359 + 1024 agree in their low 10 bits (1024 == 2**10), so
  # reducing them modulo 1024 alone would put them in the same bucket.
  # Each cross in the first row must still differ from the corresponding
  # cross in the second row.
  ids = constant_op.constant([[359], [359 + 1024]])
  digits = constant_op.constant([list(range(10)), list(range(10))])
  hashed = sparse_ops.sparse_cross_hashed(
      [digits, ids],
      num_buckets=1024,
      hash_key=sparse_ops._DEFAULT_HASH_KEY + 1)
  hashed_dense = sparse_ops.sparse_tensor_to_dense(hashed)
  with session.Session():
    rows = hashed_dense.eval()
    first_row, second_row = rows[0], rows[1]
    self.assertTrue(numpy.not_equal(first_row, second_row).all())
def test_hashed_no_hash_key(self):
  """Hashed cross into 100 buckets without an explicit hash_key."""
  op = sparse_ops.sparse_cross_hashed(
      [
          self._sparse_tensor([['batch1-FC1-F1']]),
          self._sparse_tensor([['batch1-FC2-F1']]),
          self._sparse_tensor([['batch1-FC3-F1']])
      ],
      num_buckets=100)
  # Check actual hashed output to prevent unintentional hashing changes.
  expected_out = self._sparse_tensor([[83]])
  # `test_session` is deprecated; `cached_session` is its documented
  # replacement and is what the other hashed-cross tests use.
  with self.cached_session() as sess:
    self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_hashed_output(self):
  """Hashed cross into 100 buckets with a non-default hash key."""
  op = sparse_ops.sparse_cross_hashed(
      [
          self._sparse_tensor([['batch1-FC1-F1']]),
          self._sparse_tensor([['batch1-FC2-F1']]),
          self._sparse_tensor([['batch1-FC3-F1']])
      ],
      num_buckets=100,
      hash_key=sparse_ops._DEFAULT_HASH_KEY + 1)
  # Check actual hashed output to prevent unintentional hashing changes.
  expected_out = self._sparse_tensor([[31]])
  # The result is fetched with self.evaluate, so only the session context is
  # needed here; the session object itself was never used.
  with self.cached_session():
    self._assert_sparse_tensor_equals(expected_out, self.evaluate(op))
def partial_crossing(self, partial_inputs, ragged_out, sparse_out):
  """Gets the crossed output from a partial list/tuple of inputs.

  Args:
    partial_inputs: The tensors to cross.
    ragged_out: If truthy, return a `RaggedTensor`; checked before
      `sparse_out`.
    sparse_out: If truthy (and `ragged_out` is not), return the
      `SparseTensor` produced by the cross op unchanged.

  Returns:
    The crossed values as a `RaggedTensor`, a `SparseTensor`, or a dense
    tensor, selected by the flags above. Values are hashed into
    `self.num_bins` buckets when `num_bins` is not None.
  """
  if self.num_bins is None:
    crossed = sparse_ops.sparse_cross(partial_inputs)
  else:
    crossed = sparse_ops.sparse_cross_hashed(
        partial_inputs, num_buckets=self.num_bins)

  if not ragged_out:
    return crossed if sparse_out else sparse_ops.sparse_tensor_to_dense(crossed)
  # Ragged output takes precedence: convert the sparse cross to ragged.
  return ragged_tensor.RaggedTensor.from_sparse(crossed)
def test_hashed_3x1x2(self):
  """Tests 3x1x2 permutation with hashed output."""
  fc1 = self._sparse_tensor(
      [['batch1-FC1-F1', 'batch1-FC1-F2', 'batch1-FC1-F3']])
  fc2 = self._sparse_tensor([['batch1-FC2-F1']])
  fc3 = self._sparse_tensor([['batch1-FC3-F1', 'batch1-FC3-F2']])
  op = sparse_ops.sparse_cross_hashed([fc1, fc2, fc3], num_buckets=1000)
  with self.cached_session() as sess:
    out = sess.run(op)
    # 3 * 1 * 2 = 6 crosses, all in batch row 0.
    self.assertEqual(6, len(out.values))
    self.assertAllEqual([[0, i] for i in range(6)], out.indices)
    # Every hashed value must fall in [0, num_buckets).
    self.assertTrue(all(0 <= x < 1000 for x in out.values))
    # No two crosses may share a bucket.
    unique_count = len(set(out.values))
    self.assertTrue(len(out.values) == unique_count)
def test_hashed_3x1x2(self):
  """Tests 3x1x2 permutation with hashed output."""
  num_buckets = 1000
  inputs = [
      self._sparse_tensor(
          [['batch1-FC1-F1', 'batch1-FC1-F2', 'batch1-FC1-F3']]),
      self._sparse_tensor([['batch1-FC2-F1']]),
      self._sparse_tensor([['batch1-FC3-F1', 'batch1-FC3-F2']]),
  ]
  op = sparse_ops.sparse_cross_hashed(inputs, num_buckets=num_buckets)
  # All six crosses of the single batch row, in order.
  expected_indices = [[0, col] for col in range(6)]
  with self.cached_session() as sess:
    result = sess.run(op)
    hashed = result.values
    self.assertEqual(6, len(hashed))
    self.assertAllEqual(expected_indices, result.indices)
    in_range = all(0 <= value < num_buckets for value in hashed)
    self.assertTrue(in_range)
    self.assertTrue(len(hashed) == len(set(hashed)))