コード例 #1
0
  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Returns a `LinearOperatorHouseholder` and the dense matrix it represents.

    A reflection axis of shape `build_info.shape[:-1]` is drawn with
    `random_sign_uniform(minval=1., maxval=2.)` (presumably magnitudes in
    [1, 2] with random signs) and scaled to unit norm along its last axis.
    The reference matrix is `I - 2 * v v^H`. It is assembled by forming
    `-2 * v v^H` and adding 1 to its diagonal.

    Args:
      build_info: Test build info. Only its `shape` is read.
      dtype: dtype of the generated reflection axis.
      use_placeholder: If True, the operator is built from a
        `placeholder_with_default(..., shape=None)` wrapping the axis, instead
        of from the axis tensor itself.
      ensure_self_adjoint_and_pd: Accepted for interface compatibility with
        the test harness. This method does not read it.

    Returns:
      A tuple `(operator, matrix)`.
    """
    vector_shape = list(build_info.shape)[:-1]
    raw_axis = linear_operator_test_util.random_sign_uniform(
        vector_shape, minval=1., maxval=2., dtype=dtype)
    # Normalize along the last dimension so the reflection axis has unit norm.
    unit_axis = raw_axis / linalg_ops.norm(raw_axis, axis=-1, keepdims=True)

    if use_placeholder:
      # shape=None hides the static shape from the operator.
      operator_axis = array_ops.placeholder_with_default(unit_axis, shape=None)
    else:
      operator_axis = unit_axis
    operator = householder.LinearOperatorHouseholder(operator_axis)

    # Dense reference built from the (non-placeholder) axis: I - 2 v v^H.
    column = unit_axis[..., array_ops.newaxis]
    dense = -2 * linear_operator_util.matmul_with_broadcast(
        column, column, adjoint_b=True)
    dense = array_ops.matrix_set_diag(
        dense, 1. + array_ops.matrix_diag_part(dense))
    return operator, dense
コード例 #2
0
 def test_tape_safe(self):
     """Runs `check_tape_safe` on an operator whose axis is a Variable."""
     axis_var = variables_module.Variable([1., 3., 5., 8.])
     householder_op = householder.LinearOperatorHouseholder(axis_var)
     # Quantities the operator hard-codes are skipped, since they are not
     # expected to produce gradients with respect to the variable.
     skipped = [
         # Determinant and log|det| are hard-coded. A Householder reflection
         # has det = -1, so |det| = 1 regardless of the axis.
         CheckTapeSafeSkipOptions.DETERMINANT,
         CheckTapeSafeSkipOptions.LOG_ABS_DETERMINANT,
         # Trace is hard-coded.
         CheckTapeSafeSkipOptions.TRACE,
     ]
     self.check_tape_safe(householder_op, skip_options=skipped)
コード例 #3
0
 def test_householder_inverse_type(self):
   """Checks that `inverse()` of a Householder operator is a Householder."""
   householder_op = householder.LinearOperatorHouseholder([1., 3., 5., 8.])
   inverted = householder_op.inverse()
   self.assertIsInstance(inverted, householder.LinearOperatorHouseholder)
コード例 #4
0
 def test_scalar_reflection_axis_raises(self):
   """Constructing the operator from a scalar (rank-0) axis raises ValueError."""
   # assertRaisesRegexp is a deprecated alias that was removed from
   # unittest.TestCase in Python 3.12. assertRaisesRegex is the supported name.
   with self.assertRaisesRegex(ValueError, "must have at least 1 dimension"):
     householder.LinearOperatorHouseholder(1.)
コード例 #5
0
 def test_convert_variables_to_tensors(self):
     """Runs `check_convert_variables_to_tensors` on a Variable-backed operator.

     The variable's initializer is run in the cached test session before the
     check is performed.
     """
     axis_var = variables_module.Variable([1., 3., 5., 8.])
     householder_op = householder.LinearOperatorHouseholder(axis_var)
     with self.cached_session() as session:
         init_ops = [axis_var.initializer]
         session.run(init_ops)
         self.check_convert_variables_to_tensors(householder_op)