def _assertRaises(self, diags, rhs, diags_format="compact"):
   pivoting = True
   if hasattr(self, "pivoting"):
     pivoting = self.pivoting
   with self.assertRaises(ValueError):
     linalg_impl.tridiagonal_solve(
         diags, rhs, diags_format, partial_pivoting=pivoting)
 def _assertRaises(self, diags, rhs, diags_format="compact"):
     # Skip tests for combinations with missing implementations.
     _, pivoting, perturb_singular = self._is_unimplemented()
     with self.assertRaises((NotImplementedError, ValueError)):
         linalg_impl.tridiagonal_solve(diags,
                                       rhs,
                                       diags_format,
                                       partial_pivoting=pivoting,
                                       perturb_singular=perturb_singular)
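For reference, a minimal standalone sketch of the API these helpers exercise, via the public endpoint tf.linalg.tridiagonal_solve (the system is the same one solved in the sequence-format tests further below; in compact format the last superdiagonal entry and the first subdiagonal entry are padding and are ignored):

import tensorflow as tf

# Compact format packs the diagonals into a [..., 3, n] tensor in the order
# superdiagonal, main diagonal, subdiagonal.
diags = tf.constant([[2., 1., 4., 0.],    # superdiagonal (last entry ignored)
                     [1., 3., 2., 2.],    # main diagonal
                     [0., 1., -1., 1.]],  # subdiagonal (first entry ignored)
                    dtype=tf.float64)
rhs = tf.constant([1., 2., 3., 4.], dtype=tf.float64)
x = tf.linalg.tridiagonal_solve(diags, rhs, diagonals_format="compact")
print(x.numpy())  # close to [-9., 5., -4., 4.]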
Example #3
 def _assertRaises(self, diags, rhs, diags_format="compact"):
   pivoting = True
   if hasattr(self, "pivoting"):
     pivoting = self.pivoting
   if test_util.is_xla_enabled() and pivoting:
     # Pivoting is not supported by XLA backends.
     return
   with self.assertRaises(ValueError):
     linalg_impl.tridiagonal_solve(
         diags, rhs, diags_format, partial_pivoting=pivoting)
    def _test(self,
              diags,
              rhs,
              expected,
              diags_format="compact",
              transpose_rhs=False,
              conjugate_rhs=False):
        with self.cached_session():
            # Skip tests for combinations with missing implementations.
            unimplemented, pivoting, perturb_singular = self._is_unimplemented()
            if unimplemented:
                return

            result = linalg_impl.tridiagonal_solve(
                diags,
                rhs,
                diags_format,
                transpose_rhs,
                conjugate_rhs,
                partial_pivoting=pivoting,
                perturb_singular=perturb_singular)
            result = self.evaluate(result)
            if expected is None:
                self.assertAllEqual(np.zeros_like(result, dtype=bool),
                                    np.isfinite(result))
            else:
                self.assertAllClose(result, expected)
    def _benchmark(self, generate_data_fn, test_name_format_string):
        devices = [("/cpu:0", "cpu")]
        if test.is_gpu_available(cuda_only=True):
            devices += [("/gpu:0", "gpu")]

        for device_option, pivoting_option, size_option in \
            itertools.product(devices, self.pivoting_options, self.sizes):

            device_id, device_name = device_option
            pivoting, pivoting_name = pivoting_option
            matrix_size, batch_size, num_rhs = size_option

            with ops.Graph().as_default(), \
                session.Session(config=benchmark.benchmark_config()) as sess, \
                ops.device(device_id):
                diags, rhs = generate_data_fn(matrix_size, batch_size,
                                              num_rhs)
                # Pivoting is not supported by XLA backends; skip this
                # combination rather than abort the remaining ones.
                if test.is_xla_enabled() and pivoting:
                    continue
                x = linalg_impl.tridiagonal_solve(
                    diags, rhs, partial_pivoting=pivoting)
                self.evaluate(variables.global_variables_initializer())
                self.run_op_benchmark(sess,
                                      control_flow_ops.group(x),
                                      min_iters=10,
                                      store_memory_usage=False,
                                      name=test_name_format_string.format(
                                          device_name, matrix_size,
                                          batch_size, num_rhs,
                                          pivoting_name))
Example #6
    def _gradientTest(
            self,
            diags,
            rhs,
            y,  # output = reduce_sum(y * tridiag_solve(diags, rhs))
            expected_grad_diags,  # expected gradient of output w.r.t. diags
            expected_grad_rhs,  # expected gradient of output w.r.t. rhs
            diags_format="compact",
            transpose_rhs=False,
            feed_dict=None):
        expected_grad_diags = np.array(expected_grad_diags).astype(np.float32)
        expected_grad_rhs = np.array(expected_grad_rhs).astype(np.float32)
        with self.session() as sess, self.test_scope():
            diags = _tfconst(diags)
            rhs = _tfconst(rhs)
            y = _tfconst(y)

            x = linalg_impl.tridiagonal_solve(diags,
                                              rhs,
                                              diagonals_format=diags_format,
                                              transpose_rhs=transpose_rhs,
                                              conjugate_rhs=False,
                                              partial_pivoting=False)

            res = math_ops.reduce_sum(x * y)
            actual_grad_diags = sess.run(gradient_ops.gradients(res, diags)[0],
                                         feed_dict=feed_dict)
            actual_rhs_diags = sess.run(gradient_ops.gradients(res, rhs)[0],
                                        feed_dict=feed_dict)
        self.assertAllClose(expected_grad_diags, actual_grad_diags)
        self.assertAllClose(expected_grad_rhs, actual_rhs_diags)
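The same gradients can be computed eagerly with the public API; a minimal sketch (the diagonally dominant values here are illustrative, not taken from the tests):

import tensorflow as tf

diags = tf.constant([[1., 1., 0.],   # superdiagonal (last entry ignored)
                     [2., 2., 2.],   # main diagonal
                     [0., 1., 1.]])  # subdiagonal (first entry ignored)
rhs = tf.constant([1., 1., 1.])
y = tf.constant([1., 2., 3.])

with tf.GradientTape(persistent=True) as tape:
  tape.watch([diags, rhs])
  x = tf.linalg.tridiagonal_solve(diags, rhs)
  res = tf.reduce_sum(x * y)

grad_diags = tape.gradient(res, diags)
grad_rhs = tape.gradient(res, rhs)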
Example #7
    def testPartialPivotingRaises(self):
        np.random.seed(0)
        batch_size = 8
        num_dims = 11
        num_rhs = 5

        diagonals_np = np.random.normal(size=(batch_size, 3,
                                              num_dims)).astype(np.float32)
        rhs_np = np.random.normal(size=(batch_size, num_dims,
                                        num_rhs)).astype(np.float32)

        with self.session() as sess, self.test_scope():
            with self.assertRaisesRegex(
                    errors_impl.UnimplementedError,
                    "Current implementation does not yet support pivoting."):
                diags = array_ops.placeholder(shape=(batch_size, 3, num_dims),
                                              dtype=dtypes.float32)
                rhs = array_ops.placeholder(shape=(batch_size, num_dims,
                                                   num_rhs),
                                            dtype=dtypes.float32)
                sess.run(linalg_impl.tridiagonal_solve(diags,
                                                       rhs,
                                                       partial_pivoting=True),
                         feed_dict={
                             diags: diagonals_np,
                             rhs: rhs_np
                         })
 def _test(self,
           diags,
           rhs,
           expected,
           diags_format="compact",
           transpose_rhs=False,
           conjugate_rhs=False):
     with self.cached_session():
         pivoting = True
         if hasattr(self, "pivoting"):
             pivoting = self.pivoting
         if test_util.is_xla_enabled() and pivoting:
             # Pivoting is not supported by XLA backends.
             return
         result = linalg_impl.tridiagonal_solve(diags,
                                                rhs,
                                                diags_format,
                                                transpose_rhs,
                                                conjugate_rhs,
                                                partial_pivoting=pivoting)
         result = self.evaluate(result)
         if expected is None:
              self.assertAllEqual(np.zeros_like(result, dtype=bool),
                                 np.isfinite(result))
         else:
             self.assertAllClose(result, expected)
 def _gradientTest(
         self,
         diags,
         rhs,
         y,  # output = reduce_sum(y * tridiag_solve(diags, rhs))
         expected_grad_diags,  # expected gradient of output w.r.t. diags
         expected_grad_rhs,  # expected gradient of output w.r.t. rhs
         diags_format="compact",
         transpose_rhs=False,
         conjugate_rhs=False,
         feed_dict=None):
     expected_grad_diags = _tfconst(expected_grad_diags)
     expected_grad_rhs = _tfconst(expected_grad_rhs)
     with backprop.GradientTape() as tape_diags:
         with backprop.GradientTape() as tape_rhs:
             tape_diags.watch(diags)
             tape_rhs.watch(rhs)
             if test_util.is_xla_enabled():
                  # Pivoting is not supported by XLA backends.
                 return
             x = linalg_impl.tridiagonal_solve(
                 diags,
                 rhs,
                 diagonals_format=diags_format,
                 transpose_rhs=transpose_rhs,
                 conjugate_rhs=conjugate_rhs)
             res = math_ops.reduce_sum(x * y)
     with self.cached_session() as sess:
         actual_grad_diags = sess.run(tape_diags.gradient(res, diags),
                                      feed_dict=feed_dict)
         actual_rhs_diags = sess.run(tape_rhs.gradient(res, rhs),
                                     feed_dict=feed_dict)
     self.assertAllClose(expected_grad_diags, actual_grad_diags)
     self.assertAllClose(expected_grad_rhs, actual_rhs_diags)
    def benchmarkTridiagonalSolveOp(self):
      devices = [("/cpu:0", "cpu")]
      if test.is_gpu_available(cuda_only=True):
        devices += [("/gpu:0", "gpu")]

      for device_option, pivoting_option, size_option in \
          itertools.product(devices, self.pivoting_options, self.sizes):

        device_id, device_name = device_option
        pivoting, pivoting_name = pivoting_option
        matrix_size, batch_size, num_rhs = size_option

        with ops.Graph().as_default(), \
            session.Session(config=benchmark.benchmark_config()) as sess, \
            ops.device(device_id):
          diags, rhs = self._generateData(matrix_size, batch_size, num_rhs)
          x = linalg_impl.tridiagonal_solve(
              diags, rhs, partial_pivoting=pivoting)
          variables.global_variables_initializer().run()
          self.run_op_benchmark(
              sess,
              control_flow_ops.group(x),
              min_iters=10,
              store_memory_usage=False,
              name=("tridiagonal_solve_{}_matrix_size_{}_batch_size_{}_"
                    "num_rhs_{}_{}").format(device_name, matrix_size,
                                            batch_size, num_rhs, pivoting_name))
Example #11
  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    diagonals = self.diagonals
    if adjoint:
      diagonals = self._construct_adjoint_diagonals(diagonals)

    # TODO(b/144860784): Remove the broadcasting code below once
    # tridiagonal_solve broadcasts.

    rhs_shape = array_ops.shape(rhs)
    k = self._shape_tensor(diagonals)[-1]
    broadcast_shape = array_ops.broadcast_dynamic_shape(
        self._shape_tensor(diagonals)[:-2], rhs_shape[:-2])
    rhs = array_ops.broadcast_to(
        rhs, array_ops.concat(
            [broadcast_shape, rhs_shape[-2:]], axis=-1))
    if self.diagonals_format == _MATRIX:
      diagonals = array_ops.broadcast_to(
          diagonals, array_ops.concat(
              [broadcast_shape, [k, k]], axis=-1))
    elif self.diagonals_format == _COMPACT:
      diagonals = array_ops.broadcast_to(
          diagonals, array_ops.concat(
              [broadcast_shape, [3, k]], axis=-1))
    else:
      diagonals = [
          array_ops.broadcast_to(d, array_ops.concat(
              [broadcast_shape, [k]], axis=-1)) for d in diagonals]

    y = linalg.tridiagonal_solve(
        diagonals, rhs,
        diagonals_format=self.diagonals_format,
        transpose_rhs=adjoint_arg,
        conjugate_rhs=adjoint_arg)
    return y
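This _solve appears to be the internal solve of tf.linalg.LinearOperatorTridiag; a small usage sketch of the public operator (values illustrative, chosen diagonally dominant so the operator is non-singular):

import tensorflow as tf

diagonals = tf.constant([[1., 1., 1., 0.],   # superdiagonal
                         [4., 4., 4., 4.],   # main diagonal
                         [0., 1., 1., 1.]])  # subdiagonal
operator = tf.linalg.LinearOperatorTridiag(diagonals,
                                           diagonals_format="compact",
                                           is_non_singular=True)
rhs = tf.ones([4, 2])
x = operator.solve(rhs)  # dispatches to the tridiagonal solver above
tf.debugging.assert_near(operator.matmul(x), rhs)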
 def _gradientTest(
     self,
     diags,
     rhs,
     y,  # output = reduce_sum(y * tridiag_solve(diags, rhs))
     expected_grad_diags,  # expected gradient of output w.r.t. diags
     expected_grad_rhs,  # expected gradient of output w.r.t. rhs
     diags_format="compact",
     transpose_rhs=False,
     conjugate_rhs=False,
     feed_dict=None):
   expected_grad_diags = _tfconst(expected_grad_diags)
   expected_grad_rhs = _tfconst(expected_grad_rhs)
   with backprop.GradientTape() as tape_diags:
     with backprop.GradientTape() as tape_rhs:
       tape_diags.watch(diags)
       tape_rhs.watch(rhs)
       x = linalg_impl.tridiagonal_solve(
           diags,
           rhs,
           diagonals_format=diags_format,
           transpose_rhs=transpose_rhs,
           conjugate_rhs=conjugate_rhs)
       res = math_ops.reduce_sum(x * y)
   with self.cached_session(use_gpu=True) as sess:
     actual_grad_diags = sess.run(
         tape_diags.gradient(res, diags), feed_dict=feed_dict)
     actual_rhs_diags = sess.run(
         tape_rhs.gradient(res, rhs), feed_dict=feed_dict)
   self.assertAllClose(expected_grad_diags, actual_grad_diags)
   self.assertAllClose(expected_grad_rhs, actual_rhs_diags)
Example #14
  def testSequenceFormatWithUnknownDims(self):
    if context.executing_eagerly():
      return
    if test_util.is_xla_enabled() and self.pivoting:
      # Pivoting is not supported by XLA backends.
      return
    superdiag = array_ops.placeholder(dtypes.float64, shape=[None])
    diag = array_ops.placeholder(dtypes.float64, shape=[None])
    subdiag = array_ops.placeholder(dtypes.float64, shape=[None])
    rhs = array_ops.placeholder(dtypes.float64, shape=[None])

    x = linalg_impl.tridiagonal_solve((superdiag, diag, subdiag),
                                      rhs,
                                      diagonals_format="sequence",
                                      partial_pivoting=self.pivoting)
    with self.cached_session() as sess:
      result = sess.run(
          x,
          feed_dict={
              subdiag: [20, 1, -1, 1],
              diag: [1, 3, 2, 2],
              superdiag: [2, 1, 4, 20],
              rhs: [1, 2, 3, 4]
          })
      self.assertAllClose(result, [-9, 5, -4, 4])
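Eagerly, the same system can be solved without placeholders; a sketch using the public endpoint and the same data (in sequence format the last superdiagonal entry and the first subdiagonal entry are likewise ignored):

import tensorflow as tf

superdiag = tf.constant([2., 1., 4., 20.], dtype=tf.float64)
diag = tf.constant([1., 3., 2., 2.], dtype=tf.float64)
subdiag = tf.constant([20., 1., -1., 1.], dtype=tf.float64)
rhs = tf.constant([1., 2., 3., 4.], dtype=tf.float64)

x = tf.linalg.tridiagonal_solve((superdiag, diag, subdiag), rhs,
                                diagonals_format="sequence")
print(x.numpy())  # close to [-9., 5., -4., 4.]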
 def _test(self,
           diags,
           rhs,
           expected,
           diags_format="compact",
           transpose_rhs=False,
           conjugate_rhs=False):
   with self.cached_session(use_gpu=True):
     result = linalg_impl.tridiagonal_solve(diags, rhs, diags_format,
                                            transpose_rhs, conjugate_rhs)
     self.assertAllClose(self.evaluate(result), expected)
Example #17
 def _testWithPlaceholders(self,
                           diags_shape,
                           rhs_shape,
                           diags_feed,
                           rhs_feed,
                           expected,
                           diags_format="compact"):
     if context.executing_eagerly():
         return
     diags = array_ops.placeholder(dtypes.float64, shape=diags_shape)
     rhs = array_ops.placeholder(dtypes.float64, shape=rhs_shape)
     x = linalg_impl.tridiagonal_solve(diags, rhs, diags_format)
     with self.cached_session(use_gpu=True) as sess:
         result = sess.run(x, feed_dict={diags: diags_feed, rhs: rhs_feed})
         self.assertAllClose(result, expected)
Example #19
 def _test(self,
           diags,
           rhs,
           expected,
           diags_format="compact",
           transpose_rhs=False):
     with self.session() as sess, self.test_scope():
         self.assertAllClose(
             sess.run(
                 linalg_impl.tridiagonal_solve(_tfconst(diags),
                                               _tfconst(rhs),
                                               diags_format,
                                               transpose_rhs,
                                               conjugate_rhs=False,
                                               partial_pivoting=False)),
             np.asarray(expected, dtype=np.float32))
Example #20
 def _testWithDiagonalLists(self,
                            diags,
                            rhs,
                            expected,
                            diags_format="compact",
                            transpose_rhs=False):
     with self.session() as sess, self.test_scope():
         self.assertAllClose(
             sess.run(
                 linalg_impl.tridiagonal_solve([_tfconst(x) for x in diags],
                                               _tfconst(rhs),
                                               diags_format,
                                               transpose_rhs,
                                               conjugate_rhs=False,
                                               partial_pivoting=False)),
             sess.run(_tfconst(expected)))
Example #21
  def testTridiagonalSolverSolvesKRhs(self):
    np.random.seed(19)

    batch_size = 8
    num_dims = 11
    num_rhs = 5

    diagonals_np = np.random.normal(size=(batch_size, 3,
                                          num_dims)).astype(np.float32)
    rhs_np = np.random.normal(size=(batch_size, num_dims,
                                    num_rhs)).astype(np.float32)

    with self.session() as sess, self.test_scope():
      diags = array_ops.placeholder(
          shape=(batch_size, 3, num_dims), dtype=dtypes.float32)
      rhs = array_ops.placeholder(
          shape=(batch_size, num_dims, num_rhs), dtype=dtypes.float32)
      x_np = sess.run(
          linalg_impl.tridiagonal_solve(diags, rhs, partial_pivoting=False),
          feed_dict={
              diags: diagonals_np,
              rhs: rhs_np
          })

    superdiag_np = diagonals_np[:, 0]
    diag_np = diagonals_np[:, 1]
    subdiag_np = diagonals_np[:, 2]

    for eq in range(num_rhs):
      y = np.zeros((batch_size, num_dims), dtype=np.float32)
      for i in range(num_dims):
        if i == 0:
          y[:, i] = (
              diag_np[:, i] * x_np[:, i, eq] +
              superdiag_np[:, i] * x_np[:, i + 1, eq])
        elif i == num_dims - 1:
          y[:, i] = (
              subdiag_np[:, i] * x_np[:, i - 1, eq] +
              diag_np[:, i] * x_np[:, i, eq])
        else:
          y[:, i] = (
              subdiag_np[:, i] * x_np[:, i - 1, eq] +
              diag_np[:, i] * x_np[:, i, eq] +
              superdiag_np[:, i] * x_np[:, i + 1, eq])

      self.assertAllClose(y, rhs_np[:, :, eq], rtol=1e-4, atol=1e-4)
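The element-wise verification above is a dense matrix product in disguise; an equivalent cross-check via numpy.linalg.solve (the helper compact_to_dense is ours, not from the test suite), reusing diagonals_np, rhs_np and x_np from the test above:

import numpy as np

def compact_to_dense(diags):
  """Builds a dense matrix from a [3, n] compact-format diagonals array."""
  superdiag, diag, subdiag = diags
  return (np.diag(diag) +
          np.diag(superdiag[:-1], k=1) +
          np.diag(subdiag[1:], k=-1))

a = compact_to_dense(diagonals_np[0])    # first batch member, (n, n)
x_dense = np.linalg.solve(a, rhs_np[0])  # (n, num_rhs)
np.testing.assert_allclose(x_dense, x_np[0], rtol=1e-4, atol=1e-4)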
Example #22
  def _run_test(self, batch_size, num_dims, num_rhs):
    diagonals_np = np.random.normal(size=(batch_size, 3,
                                          num_dims)).astype(np.float32)
    rhs_np = np.random.normal(size=(batch_size, num_dims,
                                    num_rhs)).astype(np.float32)

    with self.session() as sess, self.test_scope():
      diags = array_ops.placeholder(
          shape=(batch_size, 3, num_dims), dtype=dtypes.float32)
      rhs = array_ops.placeholder(
          shape=(batch_size, num_dims, num_rhs), dtype=dtypes.float32)
      x_np = sess.run(
          linalg_impl.tridiagonal_solve(diags, rhs, partial_pivoting=False),
          feed_dict={
              diags: diagonals_np,
              rhs: rhs_np
          })
      self.assertEqual(x_np.shape, (batch_size, num_dims, num_rhs))
Example #23
 def benchmarkTridiagonalSolveOp(self):
     for matrix_size, batch_size, num_rhs in self.sizes:
         with ops.Graph().as_default(), \
                 session.Session(config=benchmark.benchmark_config()) as sess, \
                 ops.device("/cpu:0"):
             diags, rhs = self._generateData(matrix_size, batch_size,
                                             num_rhs)
             x = linalg_impl.tridiagonal_solve(diags,
                                               rhs,
                                               transpose_rhs=True)
             variables.global_variables_initializer().run()
             self.run_op_benchmark(
                 sess,
                 control_flow_ops.group(x),
                 min_iters=10,
                 store_memory_usage=False,
                 name=("tridiagonal_solve_matrix_size_{}_batch_size_{}_"
                       "num_rhs_{}").format(matrix_size, batch_size,
                                            num_rhs))
Example #24
 def _testWithPlaceholders(self,
                           diags_shape,
                           rhs_shape,
                           diags_feed,
                           rhs_feed,
                           expected,
                           diags_format="compact"):
   if context.executing_eagerly():
     return
   diags = array_ops.placeholder(dtypes.float64, shape=diags_shape)
   rhs = array_ops.placeholder(dtypes.float64, shape=rhs_shape)
   if test_util.is_xla_enabled() and self.pivoting:
     # Pivoting is not supported by XLA backends.
     return
   x = linalg_impl.tridiagonal_solve(
       diags, rhs, diags_format, partial_pivoting=self.pivoting)
   with self.cached_session() as sess:
     result = sess.run(x, feed_dict={diags: diags_feed, rhs: rhs_feed})
     self.assertAllClose(result, expected)
Example #25
    def testSequenceFormatWithUnknownDims(self):
        if context.executing_eagerly():
            return
        superdiag = array_ops.placeholder(dtypes.float64, shape=[None])
        diag = array_ops.placeholder(dtypes.float64, shape=[None])
        subdiag = array_ops.placeholder(dtypes.float64, shape=[None])
        rhs = array_ops.placeholder(dtypes.float64, shape=[None])

        x = linalg_impl.tridiagonal_solve((superdiag, diag, subdiag),
                                          rhs,
                                          diagonals_format="sequence")
        with self.cached_session(use_gpu=True) as sess:
            result = sess.run(x,
                              feed_dict={
                                  subdiag: [20, 1, -1, 1],
                                  diag: [1, 3, 2, 2],
                                  superdiag: [2, 1, 4, 20],
                                  rhs: [1, 2, 3, 4]
                              })
            self.assertAllClose(result, [-9, 5, -4, 4])
Example #27
 def _test(self,
           diags,
           rhs,
           expected,
           diags_format="compact",
           transpose_rhs=False,
           conjugate_rhs=False):
     with self.cached_session(use_gpu=True):
         pivoting = True
         if hasattr(self, "pivoting"):
             pivoting = self.pivoting
         if test_util.is_xla_enabled() and pivoting:
              # Pivoting is not supported by XLA backends.
             return
         result = linalg_impl.tridiagonal_solve(diags,
                                                rhs,
                                                diags_format,
                                                transpose_rhs,
                                                conjugate_rhs,
                                                partial_pivoting=pivoting)
         self.assertAllClose(self.evaluate(result), expected)
Example #28
 def _assertRaises(self, diags, rhs, diags_format="compact"):
     with self.assertRaises(ValueError):
         linalg_impl.tridiagonal_solve(diags, rhs, diags_format)
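For context, a hedged sketch of the kind of malformed input these _assertRaises helpers feed in (compact format requires the second-to-last dimension to be 3, so a two-row diagonals tensor should be rejected with a ValueError at graph-construction time):

import tensorflow as tf

bad_diags = tf.zeros([2, 4])  # compact format expects shape [..., 3, n]
rhs = tf.zeros([4])
try:
  tf.linalg.tridiagonal_solve(bad_diags, rhs, diagonals_format="compact")
except ValueError as e:
  print("rejected:", e)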