コード例 #1
0
    def testRaggedCross(self,
                        inputs,
                        num_buckets=0,
                        hash_key=None,
                        expected=None,
                        expected_hashed=None,
                        matches_sparse_cross=True):
        """Checks ragged cross results against expectations and sparse ops.

        Args:
          inputs: Inputs handed to `ragged_array_ops.cross` and
            `ragged_array_ops.cross_hashed`.
          num_buckets: Forwarded to the hashed cross ops.
          hash_key: Forwarded to the hashed cross ops.
          expected: If not None, value the un-hashed ragged cross must equal.
          expected_hashed: If not None, value the hashed ragged cross must
            equal.
          matches_sparse_cross: If true, additionally require the ragged
            results to equal the sparse cross results converted to ragged.
        """
        actual = ragged_array_ops.cross(inputs)
        actual_hashed = ragged_array_ops.cross_hashed(
            inputs, num_buckets, hash_key)

        # Compare against the explicitly supplied expectations, when given.
        checks = ((actual, expected), (actual_hashed, expected_hashed))
        for got, want in checks:
            if want is not None:
                self.assertAllEqual(got, want)

        if not matches_sparse_cross:
            return

        # ragged.cross must agree with sparse.cross.
        as_sparse = [self._ragged_to_sparse(tensor) for tensor in inputs]
        sparse_result = sparse_ops.sparse_cross(as_sparse)
        self.assertAllEqual(
            actual, ragged_tensor.RaggedTensor.from_sparse(sparse_result))

        # ragged.cross_hashed must agree with sparse.cross_hashed. The inputs
        # are converted afresh for this second comparison.
        as_sparse = [self._ragged_to_sparse(tensor) for tensor in inputs]
        sparse_result_hashed = sparse_ops.sparse_cross_hashed(
            as_sparse, num_buckets, hash_key)
        self.assertAllEqual(
            actual_hashed,
            ragged_tensor.RaggedTensor.from_sparse(sparse_result_hashed))
コード例 #2
0
 def partial_crossing(self, partial_inputs, ragged_out, sparse_out):
   """Crosses `partial_inputs` and returns ragged, sparse or dense output.

   Args:
     partial_inputs: The list/tuple of inputs to cross.
     ragged_out: If true, return the result of `ragged_array_ops.cross`.
     sparse_out: Consulted only when `ragged_out` is false. If true, return
       the `sparse_ops.sparse_cross` result as is; otherwise return it after
       `sparse_ops.sparse_tensor_to_dense`.

   Returns:
     The crossed output in the requested representation.
   """
   # Ragged output is produced by the ragged cross op directly.
   if ragged_out:
     return ragged_array_ops.cross(partial_inputs)

   # Sparse and dense outputs share the same sparse cross.
   crossed = sparse_ops.sparse_cross(partial_inputs)
   if sparse_out:
     return crossed
   return sparse_ops.sparse_tensor_to_dense(crossed)
コード例 #3
0
 def partial_crossing(self, partial_inputs, ragged_out, sparse_out):
   """Crosses `partial_inputs` and returns ragged, sparse or dense output.

   Args:
     partial_inputs: The list/tuple of inputs to cross.
     ragged_out: If true, return the result of `ragged_array_ops.cross`.
     sparse_out: Consulted only when `ragged_out` is false. If true, return
       the `sparse_ops.sparse_cross` result as is; otherwise return it after
       `sparse_ops.sparse_tensor_to_dense`.

   Returns:
     The crossed output in the requested representation.

   Raises:
     ValueError: If `ragged_out` is true and `self.separator` is not '_X_'.
   """
   if ragged_out:
     # TODO(momernick): Support separator with ragged_cross.
     if self.separator == '_X_':
       return ragged_array_ops.cross(partial_inputs)
     message = ('Non-default separator with ragged input is not '
                'supported yet, given {}').format(self.separator)
     raise ValueError(message)

   # Sparse and dense outputs share the same sparse cross.
   crossed = sparse_ops.sparse_cross(partial_inputs, separator=self.separator)
   if sparse_out:
     return crossed
   return sparse_ops.sparse_tensor_to_dense(crossed)
コード例 #4
0
    def testRaggedCrossLargeBatch(self):
        """Crosses ragged, dense and sparse inputs over a 5000-row batch."""
        num_rows = 5000

        # Every row of the batch is identical, so one expected row is repeated.
        expected_row = [
            b'1_X_4_X_5_X_6', b'1_X_4_X_5_X_7', b'2_X_4_X_5_X_6',
            b'2_X_4_X_5_X_7', b'3_X_4_X_5_X_6', b'3_X_4_X_5_X_7'
        ]
        expected = [expected_row] * num_rows

        inputs = [
            ragged_const([[1, 2, 3]] * num_rows),
            ragged_const([[b'4']] * num_rows),
            dense_const([[5]] * num_rows),
            sparse_const([[6, 7]] * num_rows)
        ]
        crossed = ragged_array_ops.cross(inputs)
        actual = self.evaluate(crossed).to_list()

        # Note: we don't use assertAllEqual here because if they don't match,
        # then the code in assertAllEqual that tries to build the error message
        # is very slow, causing the test to timeout.
        # pylint: disable=g-generic-assert
        self.assertTrue(actual == expected)
コード例 #5
0
 def testRuntimeError(self,
                      inputs,
                      exception=errors.InvalidArgumentError,
                      message=None):
     """Expects building and evaluating the ragged cross of `inputs` to raise.

     Args:
       inputs: Inputs handed to `ragged_array_ops.cross`.
       exception: Exception type that must be raised.
       message: Regex passed to `assertRaisesRegex` alongside `exception`.
     """
     with self.assertRaisesRegex(exception, message):
         crossed = ragged_array_ops.cross(inputs)
         self.evaluate(crossed)
コード例 #6
0
 def testStaticError(self, inputs, exception=ValueError, message=None):
     """Expects `ragged_array_ops.cross(inputs)` itself to raise.

     Args:
       inputs: Inputs handed to `ragged_array_ops.cross`.
       exception: Exception type that must be raised.
       message: Regex passed to `assertRaisesRegex` alongside `exception`.
     """
     # Callable form: the op is invoked inside the assertion context.
     self.assertRaisesRegex(
         exception, message, ragged_array_ops.cross, inputs)
コード例 #7
0
 def fn(x):
     """Returns `ragged_array_ops.cross` applied to `x`."""
     crossed = ragged_array_ops.cross(x)
     return crossed