def _assertRaises(self, diags, rhs, diags_format="compact"): pivoting = True if hasattr(self, "pivoting"): pivoting = self.pivoting with self.assertRaises(ValueError): linalg_impl.tridiagonal_solve( diags, rhs, diags_format, partial_pivoting=pivoting)
def _assertRaises(self, diags, rhs, diags_format="compact"): # Skip tests for combinations with missing implementations. _, pivoting, perturb_singular = self._is_unimplemented() with self.assertRaises((NotImplementedError, ValueError)): linalg_impl.tridiagonal_solve(diags, rhs, diags_format, partial_pivoting=pivoting, perturb_singular=perturb_singular)
def _assertRaises(self, diags, rhs, diags_format="compact"): pivoting = True if hasattr(self, "pivoting"): pivoting = self.pivoting if test_util.is_xla_enabled() and pivoting: # Pivoting is not supported by xla backends. return with self.assertRaises(ValueError): linalg_impl.tridiagonal_solve( diags, rhs, diags_format, partial_pivoting=pivoting)
def _test(self,
          diags,
          rhs,
          expected,
          diags_format="compact",
          transpose_rhs=False,
          conjugate_rhs=False):
  with self.cached_session():
    # Skip tests for combinations with missing implementations.
    unimplemented, pivoting, perturb_singular = self._is_unimplemented()
    if unimplemented:
      return
    result = linalg_impl.tridiagonal_solve(
        diags,
        rhs,
        diags_format,
        transpose_rhs,
        conjugate_rhs,
        partial_pivoting=pivoting,
        perturb_singular=perturb_singular)
    result = self.evaluate(result)
    if expected is None:
      # Singular system: every entry of the result should be non-finite.
      self.assertAllEqual(
          np.zeros_like(result, dtype=bool), np.isfinite(result))
    else:
      self.assertAllClose(result, expected)
def _benchmark(self, generate_data_fn, test_name_format_string):
  devices = [("/cpu:0", "cpu")]
  if test.is_gpu_available(cuda_only=True):
    devices += [("/gpu:0", "gpu")]

  for device_option, pivoting_option, size_option in itertools.product(
      devices, self.pivoting_options, self.sizes):
    device_id, device_name = device_option
    pivoting, pivoting_name = pivoting_option
    matrix_size, batch_size, num_rhs = size_option

    # Pivoting is not supported by XLA backends; skip this combination
    # instead of aborting the remaining benchmark configurations.
    if test.is_xla_enabled() and pivoting:
      continue

    with ops.Graph().as_default(), \
        session.Session(config=benchmark.benchmark_config()) as sess, \
        ops.device(device_id):
      diags, rhs = generate_data_fn(matrix_size, batch_size, num_rhs)
      x = linalg_impl.tridiagonal_solve(
          diags, rhs, partial_pivoting=pivoting)
      self.evaluate(variables.global_variables_initializer())
      self.run_op_benchmark(
          sess,
          control_flow_ops.group(x),
          min_iters=10,
          store_memory_usage=False,
          name=test_name_format_string.format(device_name, matrix_size,
                                              batch_size, num_rhs,
                                              pivoting_name))
def _gradientTest(
    self,
    diags,
    rhs,
    y,  # output = reduce_sum(y * tridiag_solve(diags, rhs))
    expected_grad_diags,  # expected gradient of output w.r.t. diags
    expected_grad_rhs,  # expected gradient of output w.r.t. rhs
    diags_format="compact",
    transpose_rhs=False,
    feed_dict=None):
  expected_grad_diags = np.array(expected_grad_diags).astype(np.float32)
  expected_grad_rhs = np.array(expected_grad_rhs).astype(np.float32)
  with self.session() as sess, self.test_scope():
    diags = _tfconst(diags)
    rhs = _tfconst(rhs)
    y = _tfconst(y)
    x = linalg_impl.tridiagonal_solve(
        diags,
        rhs,
        diagonals_format=diags_format,
        transpose_rhs=transpose_rhs,
        conjugate_rhs=False,
        partial_pivoting=False)
    res = math_ops.reduce_sum(x * y)
    actual_grad_diags = sess.run(
        gradient_ops.gradients(res, diags)[0], feed_dict=feed_dict)
    actual_grad_rhs = sess.run(
        gradient_ops.gradients(res, rhs)[0], feed_dict=feed_dict)
    self.assertAllClose(expected_grad_diags, actual_grad_diags)
    self.assertAllClose(expected_grad_rhs, actual_grad_rhs)
def testPartialPivotingRaises(self):
  np.random.seed(0)
  batch_size = 8
  num_dims = 11
  num_rhs = 5
  diagonals_np = np.random.normal(size=(batch_size, 3,
                                        num_dims)).astype(np.float32)
  rhs_np = np.random.normal(size=(batch_size, num_dims,
                                  num_rhs)).astype(np.float32)
  with self.session() as sess, self.test_scope():
    with self.assertRaisesRegex(
        errors_impl.UnimplementedError,
        "Current implementation does not yet support pivoting."):
      diags = array_ops.placeholder(
          shape=(batch_size, 3, num_dims), dtype=dtypes.float32)
      rhs = array_ops.placeholder(
          shape=(batch_size, num_dims, num_rhs), dtype=dtypes.float32)
      sess.run(
          linalg_impl.tridiagonal_solve(diags, rhs, partial_pivoting=True),
          feed_dict={
              diags: diagonals_np,
              rhs: rhs_np
          })
def _test(self, diags, rhs, expected, diags_format="compact", transpose_rhs=False, conjugate_rhs=False): with self.cached_session(): pivoting = True if hasattr(self, "pivoting"): pivoting = self.pivoting if test_util.is_xla_enabled() and pivoting: # Pivoting is not supported by xla backends. return result = linalg_impl.tridiagonal_solve(diags, rhs, diags_format, transpose_rhs, conjugate_rhs, partial_pivoting=pivoting) result = self.evaluate(result) if expected is None: self.assertAllEqual(np.zeros_like(result, dtype=np.bool), np.isfinite(result)) else: self.assertAllClose(result, expected)
def _gradientTest(
    self,
    diags,
    rhs,
    y,  # output = reduce_sum(y * tridiag_solve(diags, rhs))
    expected_grad_diags,  # expected gradient of output w.r.t. diags
    expected_grad_rhs,  # expected gradient of output w.r.t. rhs
    diags_format="compact",
    transpose_rhs=False,
    conjugate_rhs=False,
    feed_dict=None):
  if test_util.is_xla_enabled():
    # Pivoting is not supported by XLA backends.
    return
  expected_grad_diags = _tfconst(expected_grad_diags)
  expected_grad_rhs = _tfconst(expected_grad_rhs)
  with backprop.GradientTape() as tape_diags:
    with backprop.GradientTape() as tape_rhs:
      tape_diags.watch(diags)
      tape_rhs.watch(rhs)
      x = linalg_impl.tridiagonal_solve(
          diags,
          rhs,
          diagonals_format=diags_format,
          transpose_rhs=transpose_rhs,
          conjugate_rhs=conjugate_rhs)
      res = math_ops.reduce_sum(x * y)
  with self.cached_session() as sess:
    actual_grad_diags = sess.run(
        tape_diags.gradient(res, diags), feed_dict=feed_dict)
    actual_grad_rhs = sess.run(
        tape_rhs.gradient(res, rhs), feed_dict=feed_dict)
  self.assertAllClose(expected_grad_diags, actual_grad_diags)
  self.assertAllClose(expected_grad_rhs, actual_grad_rhs)
def benchmarkTridiagonalSolveOp(self):
  devices = [("/cpu:0", "cpu")]
  if test.is_gpu_available(cuda_only=True):
    devices += [("/gpu:0", "gpu")]

  for device_option, pivoting_option, size_option in itertools.product(
      devices, self.pivoting_options, self.sizes):
    device_id, device_name = device_option
    pivoting, pivoting_name = pivoting_option
    matrix_size, batch_size, num_rhs = size_option

    with ops.Graph().as_default(), \
        session.Session(config=benchmark.benchmark_config()) as sess, \
        ops.device(device_id):
      diags, rhs = self._generateData(matrix_size, batch_size, num_rhs)
      x = linalg_impl.tridiagonal_solve(
          diags, rhs, partial_pivoting=pivoting)
      variables.global_variables_initializer().run()
      self.run_op_benchmark(
          sess,
          control_flow_ops.group(x),
          min_iters=10,
          store_memory_usage=False,
          name=("tridiagonal_solve_{}_matrix_size_{}_batch_size_{}_"
                "num_rhs_{}_{}").format(device_name, matrix_size, batch_size,
                                        num_rhs, pivoting_name))
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
  diagonals = self.diagonals
  if adjoint:
    diagonals = self._construct_adjoint_diagonals(diagonals)

  # TODO(b/144860784): Remove the broadcasting code below once
  # tridiagonal_solve broadcasts.
  rhs_shape = array_ops.shape(rhs)
  k = self._shape_tensor(diagonals)[-1]
  broadcast_shape = array_ops.broadcast_dynamic_shape(
      self._shape_tensor(diagonals)[:-2], rhs_shape[:-2])
  rhs = array_ops.broadcast_to(
      rhs, array_ops.concat([broadcast_shape, rhs_shape[-2:]], axis=-1))
  if self.diagonals_format == _MATRIX:
    diagonals = array_ops.broadcast_to(
        diagonals, array_ops.concat([broadcast_shape, [k, k]], axis=-1))
  elif self.diagonals_format == _COMPACT:
    diagonals = array_ops.broadcast_to(
        diagonals, array_ops.concat([broadcast_shape, [3, k]], axis=-1))
  else:
    diagonals = [
        array_ops.broadcast_to(
            d, array_ops.concat([broadcast_shape, [k]], axis=-1))
        for d in diagonals
    ]

  y = linalg.tridiagonal_solve(
      diagonals,
      rhs,
      diagonals_format=self.diagonals_format,
      transpose_rhs=adjoint_arg,
      conjugate_rhs=adjoint_arg)
  return y
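# For reference, a minimal usage sketch (not from the original code) of the
# public tf.linalg.tridiagonal_solve API in compact format. It reuses the
# 4x4 system that appears in the placeholder tests below; the zero padding
# entries are ignored by the solver.
import tensorflow as tf

# Compact format packs the three diagonals into shape [3, num_dims]:
# row 0 = superdiagonal (last entry ignored), row 1 = main diagonal,
# row 2 = subdiagonal (first entry ignored).
diagonals = tf.constant([[2., 1., 4., 0.],
                         [1., 3., 2., 2.],
                         [0., 1., -1., 1.]], dtype=tf.float64)
rhs = tf.constant([1., 2., 3., 4.], dtype=tf.float64)
x = tf.linalg.tridiagonal_solve(diagonals, rhs, diagonals_format="compact")
print(x.numpy())  # approximately [-9., 5., -4., 4.]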
def _gradientTest(
    self,
    diags,
    rhs,
    y,  # output = reduce_sum(y * tridiag_solve(diags, rhs))
    expected_grad_diags,  # expected gradient of output w.r.t. diags
    expected_grad_rhs,  # expected gradient of output w.r.t. rhs
    diags_format="compact",
    transpose_rhs=False,
    conjugate_rhs=False,
    feed_dict=None):
  expected_grad_diags = _tfconst(expected_grad_diags)
  expected_grad_rhs = _tfconst(expected_grad_rhs)
  with backprop.GradientTape() as tape_diags:
    with backprop.GradientTape() as tape_rhs:
      tape_diags.watch(diags)
      tape_rhs.watch(rhs)
      x = linalg_impl.tridiagonal_solve(
          diags,
          rhs,
          diagonals_format=diags_format,
          transpose_rhs=transpose_rhs,
          conjugate_rhs=conjugate_rhs)
      res = math_ops.reduce_sum(x * y)
  with self.cached_session(use_gpu=True) as sess:
    actual_grad_diags = sess.run(
        tape_diags.gradient(res, diags), feed_dict=feed_dict)
    actual_grad_rhs = sess.run(
        tape_rhs.gradient(res, rhs), feed_dict=feed_dict)
  self.assertAllClose(expected_grad_diags, actual_grad_diags)
  self.assertAllClose(expected_grad_rhs, actual_grad_rhs)
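# The gradient tests above compare against precomputed expected gradients.
# As a hedged sketch (not part of the original tests), the same gradients
# can be cross-checked numerically with tf.test.compute_gradient; the
# well-conditioned 3x3 system below is illustrative.
import numpy as np
import tensorflow as tf

diags = tf.constant([[1., 1., 0.],   # superdiagonal (last entry ignored)
                     [3., 4., 5.],   # main diagonal
                     [0., 1., 1.]],  # subdiagonal (first entry ignored)
                    dtype=tf.float64)
rhs = tf.constant([1., 2., 3.], dtype=tf.float64)

def solve_sum(d, r):
  return tf.reduce_sum(tf.linalg.tridiagonal_solve(d, r))

# Compare analytic Jacobians against central finite differences.
theoretical, numerical = tf.test.compute_gradient(solve_sum, [diags, rhs])
for t, n in zip(theoretical, numerical):
  np.testing.assert_allclose(t, n, rtol=1e-5, atol=1e-5)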
def testSequenceFormatWithUnknownDims(self):
  if context.executing_eagerly():
    return
  if test_util.is_xla_enabled() and self.pivoting:
    # Pivoting is not supported by XLA backends.
    return
  superdiag = array_ops.placeholder(dtypes.float64, shape=[None])
  diag = array_ops.placeholder(dtypes.float64, shape=[None])
  subdiag = array_ops.placeholder(dtypes.float64, shape=[None])
  rhs = array_ops.placeholder(dtypes.float64, shape=[None])
  x = linalg_impl.tridiagonal_solve((superdiag, diag, subdiag),
                                    rhs,
                                    diagonals_format="sequence",
                                    partial_pivoting=self.pivoting)
  with self.cached_session() as sess:
    result = sess.run(
        x,
        feed_dict={
            subdiag: [20, 1, -1, 1],
            diag: [1, 3, 2, 2],
            superdiag: [2, 1, 4, 20],
            rhs: [1, 2, 3, 4]
        })
    self.assertAllClose(result, [-9, 5, -4, 4])
def _test(self, diags, rhs, expected, diags_format="compact", transpose_rhs=False, conjugate_rhs=False): with self.cached_session(use_gpu=True): result = linalg_impl.tridiagonal_solve(diags, rhs, diags_format, transpose_rhs, conjugate_rhs) self.assertAllClose(self.evaluate(result), expected)
def _testWithPlaceholders(self, diags_shape, rhs_shape, diags_feed, rhs_feed, expected, diags_format="compact"): if context.executing_eagerly(): return diags = array_ops.placeholder(dtypes.float64, shape=diags_shape) rhs = array_ops.placeholder(dtypes.float64, shape=rhs_shape) x = linalg_impl.tridiagonal_solve(diags, rhs, diags_format) with self.cached_session(use_gpu=True) as sess: result = sess.run(x, feed_dict={diags: diags_feed, rhs: rhs_feed}) self.assertAllClose(result, expected)
def _test(self, diags, rhs, expected, diags_format="compact", transpose_rhs=False): with self.session() as sess, self.test_scope(): self.assertAllClose( sess.run( linalg_impl.tridiagonal_solve(_tfconst(diags), _tfconst(rhs), diags_format, transpose_rhs, conjugate_rhs=False, partial_pivoting=False)), np.asarray(expected, dtype=np.float32))
def _testWithDiagonalLists(self,
                           diags,
                           rhs,
                           expected,
                           diags_format="compact",
                           transpose_rhs=False):
  with self.session() as sess, self.test_scope():
    self.assertAllClose(
        sess.run(
            linalg_impl.tridiagonal_solve(
                [_tfconst(x) for x in diags],
                _tfconst(rhs),
                diags_format,
                transpose_rhs,
                conjugate_rhs=False,
                partial_pivoting=False)),
        sess.run(_tfconst(expected)))
def testTridiagonalSolverSolvesKRhs(self):
  np.random.seed(19)
  batch_size = 8
  num_dims = 11
  num_rhs = 5
  diagonals_np = np.random.normal(size=(batch_size, 3,
                                        num_dims)).astype(np.float32)
  rhs_np = np.random.normal(size=(batch_size, num_dims,
                                  num_rhs)).astype(np.float32)
  with self.session() as sess, self.test_scope():
    diags = array_ops.placeholder(
        shape=(batch_size, 3, num_dims), dtype=dtypes.float32)
    rhs = array_ops.placeholder(
        shape=(batch_size, num_dims, num_rhs), dtype=dtypes.float32)
    x_np = sess.run(
        linalg_impl.tridiagonal_solve(diags, rhs, partial_pivoting=False),
        feed_dict={
            diags: diagonals_np,
            rhs: rhs_np
        })

    # Verify that A @ x reproduces each right-hand side, recomputing the
    # matrix-vector product row by row from the three diagonals.
    superdiag_np = diagonals_np[:, 0]
    diag_np = diagonals_np[:, 1]
    subdiag_np = diagonals_np[:, 2]
    for eq in range(num_rhs):
      y = np.zeros((batch_size, num_dims), dtype=np.float32)
      for i in range(num_dims):
        if i == 0:
          y[:, i] = (diag_np[:, i] * x_np[:, i, eq] +
                     superdiag_np[:, i] * x_np[:, i + 1, eq])
        elif i == num_dims - 1:
          y[:, i] = (subdiag_np[:, i] * x_np[:, i - 1, eq] +
                     diag_np[:, i] * x_np[:, i, eq])
        else:
          y[:, i] = (subdiag_np[:, i] * x_np[:, i - 1, eq] +
                     diag_np[:, i] * x_np[:, i, eq] +
                     superdiag_np[:, i] * x_np[:, i + 1, eq])
      self.assertAllClose(y, rhs_np[:, :, eq], rtol=1e-4, atol=1e-4)
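# The per-row verification loop above is equivalent to rebuilding the dense
# tridiagonal matrix from the compact-format diagonals and multiplying. A
# small NumPy sketch of that equivalence (the helper name is illustrative,
# not from the original code):
import numpy as np

def dense_from_compact(diagonals):
  """Dense tridiagonal matrix from compact-format diagonals of shape (3, n).

  Row 0 is the superdiagonal (last entry ignored), row 1 the main diagonal,
  row 2 the subdiagonal (first entry ignored).
  """
  superdiag, diag, subdiag = diagonals
  return (np.diag(diag) + np.diag(superdiag[:-1], k=1) +
          np.diag(subdiag[1:], k=-1))

# For each batch b and right-hand side eq, the loop above computes
#   y = dense_from_compact(diagonals_np[b]) @ x_np[b, :, eq]
# and checks y against rhs_np[b, :, eq].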
def _run_test(self, batch_size, num_dims, num_rhs):
  diagonals_np = np.random.normal(size=(batch_size, 3,
                                        num_dims)).astype(np.float32)
  rhs_np = np.random.normal(size=(batch_size, num_dims,
                                  num_rhs)).astype(np.float32)
  with self.session() as sess, self.test_scope():
    diags = array_ops.placeholder(
        shape=(batch_size, 3, num_dims), dtype=dtypes.float32)
    rhs = array_ops.placeholder(
        shape=(batch_size, num_dims, num_rhs), dtype=dtypes.float32)
    x_np = sess.run(
        linalg_impl.tridiagonal_solve(diags, rhs, partial_pivoting=False),
        feed_dict={
            diags: diagonals_np,
            rhs: rhs_np
        })
    self.assertEqual(x_np.shape, (batch_size, num_dims, num_rhs))
def benchmarkTridiagonalSolveOp(self):
  for matrix_size, batch_size, num_rhs in self.sizes:
    with ops.Graph().as_default(), \
        session.Session(config=benchmark.benchmark_config()) as sess, \
        ops.device("/cpu:0"):
      diags, rhs = self._generateData(matrix_size, batch_size, num_rhs)
      x = linalg_impl.tridiagonal_solve(diags, rhs, transpose_rhs=True)
      variables.global_variables_initializer().run()
      self.run_op_benchmark(
          sess,
          control_flow_ops.group(x),
          min_iters=10,
          store_memory_usage=False,
          name=("tridiagonal_solve_matrix_size_{}_batch_size_{}_"
                "num_rhs_{}").format(matrix_size, batch_size, num_rhs))
def _testWithPlaceholders(self, diags_shape, rhs_shape, diags_feed, rhs_feed, expected, diags_format="compact"): if context.executing_eagerly(): return diags = array_ops.placeholder(dtypes.float64, shape=diags_shape) rhs = array_ops.placeholder(dtypes.float64, shape=rhs_shape) if test_util.is_xla_enabled() and self.pivoting: # Pivoting is not supported by xla backends. return x = linalg_impl.tridiagonal_solve( diags, rhs, diags_format, partial_pivoting=self.pivoting) with self.cached_session() as sess: result = sess.run(x, feed_dict={diags: diags_feed, rhs: rhs_feed}) self.assertAllClose(result, expected)
def testSequenceFormatWithUnknownDims(self):
  if context.executing_eagerly():
    return
  superdiag = array_ops.placeholder(dtypes.float64, shape=[None])
  diag = array_ops.placeholder(dtypes.float64, shape=[None])
  subdiag = array_ops.placeholder(dtypes.float64, shape=[None])
  rhs = array_ops.placeholder(dtypes.float64, shape=[None])
  x = linalg_impl.tridiagonal_solve((superdiag, diag, subdiag),
                                    rhs,
                                    diagonals_format="sequence")
  with self.cached_session(use_gpu=True) as sess:
    # The first subdiagonal entry and the last superdiagonal entry (both
    # 20 here) are padding and ignored by the solver.
    result = sess.run(
        x,
        feed_dict={
            subdiag: [20, 1, -1, 1],
            diag: [1, 3, 2, 2],
            superdiag: [2, 1, 4, 20],
            rhs: [1, 2, 3, 4]
        })
    self.assertAllClose(result, [-9, 5, -4, 4])
def _test(self, diags, rhs, expected, diags_format="compact", transpose_rhs=False, conjugate_rhs=False): with self.cached_session(use_gpu=True): pivoting = True if hasattr(self, "pivoting"): pivoting = self.pivoting if test_util.is_xla_enabled() and pivoting: # Pivoting is not supported by xla backends. return result = linalg_impl.tridiagonal_solve(diags, rhs, diags_format, transpose_rhs, conjugate_rhs, partial_pivoting=pivoting) self.assertAllClose(self.evaluate(result), expected)
def _assertRaises(self, diags, rhs, diags_format="compact"):
  with self.assertRaises(ValueError):
    linalg_impl.tridiagonal_solve(diags, rhs, diags_format)
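# As a usage note (an assumption about the op's shape validation, not taken
# from the tests above): compact-format diagonals must have 3 rows, so a
# 2-row tensor is one input that should trigger the ValueError asserted here.
import tensorflow as tf

bad_diags = tf.zeros([2, 4])  # compact format requires shape [..., 3, n]
rhs = tf.zeros([4])
try:
  tf.linalg.tridiagonal_solve(bad_diags, rhs, diagonals_format="compact")
except ValueError as e:
  print("rejected:", e)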