def testRaggedCross(self,
                    inputs,
                    num_buckets=0,
                    hash_key=None,
                    expected=None,
                    expected_hashed=None,
                    matches_sparse_cross=True):
  # Runs both the plain and the hashed ragged cross over `inputs`, compares
  # each result with its expected value (when one is supplied), and, when
  # `matches_sparse_cross` is set, also compares each result with the output
  # of the corresponding sparse op.
  #
  # `num_buckets` and `hash_key` are forwarded only to the hashed variants.
  crossed = ragged_array_ops.cross(inputs)
  crossed_hashed = ragged_array_ops.cross_hashed(inputs, num_buckets, hash_key)

  if expected is not None:
    self.assertAllEqual(crossed, expected)
  if expected_hashed is not None:
    self.assertAllEqual(crossed_hashed, expected_hashed)

  if not matches_sparse_cross:
    return

  # ragged.cross must agree with sparse.cross.
  as_sparse = [self._ragged_to_sparse(value) for value in inputs]
  sparse_result = sparse_ops.sparse_cross(as_sparse)
  self.assertAllEqual(
      crossed, ragged_tensor.RaggedTensor.from_sparse(sparse_result))

  # ragged.cross_hashed must agree with sparse.cross_hashed. The sparse
  # inputs are rebuilt here rather than reused from the check above.
  as_sparse = [self._ragged_to_sparse(value) for value in inputs]
  sparse_hashed_result = sparse_ops.sparse_cross_hashed(
      as_sparse, num_buckets, hash_key)
  self.assertAllEqual(
      crossed_hashed,
      ragged_tensor.RaggedTensor.from_sparse(sparse_hashed_result))
def partial_crossing(self, partial_inputs, ragged_out, sparse_out):
  """Crosses `partial_inputs` and returns the result in the requested form.

  Args:
    partial_inputs: The list/tuple of inputs to cross.
    ragged_out: If true, the inputs are crossed with `ragged_array_ops.cross`
      and that result is returned. Takes precedence over `sparse_out`.
    sparse_out: If true (and `ragged_out` is false), the result of
      `sparse_ops.sparse_cross` is returned as-is.

  Returns:
    The ragged cross, the sparse cross, or, when neither flag is set, the
    sparse cross converted with `sparse_ops.sparse_tensor_to_dense`.
  """
  if ragged_out:
    return ragged_array_ops.cross(partial_inputs)
  if sparse_out:
    return sparse_ops.sparse_cross(partial_inputs)
  # Neither ragged nor sparse output was requested: densify the sparse cross.
  sparse_result = sparse_ops.sparse_cross(partial_inputs)
  return sparse_ops.sparse_tensor_to_dense(sparse_result)
def partial_crossing(self, partial_inputs, ragged_out, sparse_out):
  """Crosses `partial_inputs` and returns the result in the requested form.

  Args:
    partial_inputs: The list/tuple of inputs to cross.
    ragged_out: If true, the inputs are crossed with `ragged_array_ops.cross`
      and that result is returned. Takes precedence over `sparse_out`.
    sparse_out: If true (and `ragged_out` is false), the result of
      `sparse_ops.sparse_cross` is returned as-is.

  Returns:
    The ragged cross, the sparse cross, or, when neither flag is set, the
    sparse cross converted with `sparse_ops.sparse_tensor_to_dense`. The
    sparse paths pass `self.separator` through to `sparse_cross`.

  Raises:
    ValueError: If `ragged_out` is true and `self.separator` is not `'_X_'`.
  """
  if ragged_out:
    # TODO(momernick): Support separator with ragged_cross.
    separator = self.separator
    if separator != '_X_':
      raise ValueError('Non-default separator with ragged input is not '
                       'supported yet, given {}'.format(separator))
    return ragged_array_ops.cross(partial_inputs)
  if sparse_out:
    return sparse_ops.sparse_cross(partial_inputs, separator=self.separator)
  # Neither ragged nor sparse output was requested: densify the sparse cross.
  sparse_result = sparse_ops.sparse_cross(
      partial_inputs, separator=self.separator)
  return sparse_ops.sparse_tensor_to_dense(sparse_result)
def testRaggedCrossLargeBatch(self):
  # Crosses a batch of 5000 identical rows built from ragged, dense and
  # sparse inputs, and checks that every output row holds the same six
  # crossed values.
  batch_size = 5000
  inputs = [
      ragged_const([[1, 2, 3]] * batch_size),
      ragged_const([[b'4']] * batch_size),
      dense_const([[5]] * batch_size),
      sparse_const([[6, 7]] * batch_size),
  ]
  expected_row = [
      b'1_X_4_X_5_X_6', b'1_X_4_X_5_X_7',
      b'2_X_4_X_5_X_6', b'2_X_4_X_5_X_7',
      b'3_X_4_X_5_X_6', b'3_X_4_X_5_X_7',
  ]
  expected = [expected_row] * batch_size

  crossed = ragged_array_ops.cross(inputs)
  actual = self.evaluate(crossed).to_list()

  # assertAllEqual is deliberately avoided: on a mismatch, building its error
  # message for a batch this large is slow enough to time the test out.
  # pylint: disable=g-generic-assert
  self.assertTrue(actual == expected)
def testRuntimeError(self,
                     inputs,
                     exception=errors.InvalidArgumentError,
                     message=None):
  # Building and evaluating the ragged cross of `inputs` must raise
  # `exception` with a message matching the regex `message`. Both steps stay
  # inside the context so an error from either one is caught.
  with self.assertRaisesRegex(exception, message):
    crossed = ragged_array_ops.cross(inputs)
    self.evaluate(crossed)
def testStaticError(self, inputs, exception=ValueError, message=None):
  # Calling the ragged cross on `inputs` must itself raise `exception` with a
  # message matching the regex `message`; the result is never evaluated.
  expect_failure = self.assertRaisesRegex(exception, message)
  with expect_failure:
    ragged_array_ops.cross(inputs)
def fn(x):
  """Returns the result of `ragged_array_ops.cross` applied to `x`."""
  crossed = ragged_array_ops.cross(x)
  return crossed