Пример #1
0
    def test_tape_safe(self):
        """Checks that a 'sequence'-format LinearOperatorTridiag is tape safe.

        Two configurations are exercised:

        1. All three diagonals are Variables. `diag_part` and `trace` are
           skipped for this operator because they only depend on the middle
           Variable.
        2. Only the middle diagonal is a Variable, so every check, including
           `diag_part` and `trace`, is run.
        """
        diagonals = [
            variables_module.Variable([3., 6., 2.]),
            variables_module.Variable([2., 4., 2.]),
            variables_module.Variable([5., 1., 2.])
        ]
        operator = linalg_lib.LinearOperatorTridiag(
            diagonals, diagonals_format='sequence')
        # Skip the diagonal part and trace since they only depend on the
        # middle variable. They are tested with the second operator below.
        self.check_tape_safe(operator, skip_options=['diag_part', 'trace'])

        # Only the middle diagonal is a Variable here, so diag_part and trace
        # can be checked together with everything else.
        diagonals = [[3., 6., 2.],
                     variables_module.Variable([2., 4., 2.]), [5., 1., 2.]]
        operator = linalg_lib.LinearOperatorTridiag(
            diagonals, diagonals_format='sequence')
        self.check_tape_safe(operator)
Пример #2
0
 def test_convert_variables_to_tensors(self):
     """Checks variable-to-tensor conversion for a 'compact' tridiag operator.

     The operator is backed by a single Variable whose rows are the three
     diagonals. The Variable is initialized inside a cached session before
     `check_convert_variables_to_tensors` is invoked.
     """
     compact_diagonals = variables_module.Variable(
         [[3., 6., 2.],
          [2., 4., 2.],
          [5., 1., 2.]])
     tridiag_operator = linalg_lib.LinearOperatorTridiag(
         compact_diagonals, diagonals_format='compact')
     with self.cached_session() as session:
         # Explicit initialization; presumably required under graph-mode
         # execution — TODO confirm whether eager runs need it.
         session.run([compact_diagonals.initializer])
         self.check_convert_variables_to_tensors(tridiag_operator)
  def build_operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False,
      diagonals_format='sequence'):
    """Builds a random LinearOperatorTridiag and its equivalent dense matrix.

    Args:
      build_info: Object with a `shape` attribute giving the operator shape.
        Each generated diagonal has shape `shape[:-1]`.
      dtype: dtype of the generated diagonals and dense matrix.
      use_placeholder: If True, the diagonals handed to the operator are
        wrapped in `placeholder_with_default(..., shape=None)`, so their
        static shape is unknown to the operator.
      ensure_self_adjoint_and_pd: If True, the superdiagonal is the
        conjugate of the (shifted) subdiagonal and the main diagonal is made
        non-negative and dominant. The operator is then flagged as
        self-adjoint and positive definite.
      diagonals_format: One of 'sequence', 'compact' or 'matrix'. It selects
        how the diagonals are passed to the operator.

    Returns:
      A tuple `(operator, matrix)`, where `matrix` is the dense tensor that
      `operator` is expected to represent.

    Raises:
      ValueError: If `diagonals_format` is not a supported format.
    """
    # Validate first, so that an unsupported format fails with a clear error.
    # Otherwise it would surface later as an UnboundLocalError on `diagonals`.
    valid_formats = ('sequence', 'compact', 'matrix')
    if diagonals_format not in valid_formats:
      raise ValueError(
          'diagonals_format must be one of {}, got {!r}'.format(
              valid_formats, diagonals_format))

    shape = list(build_info.shape)

    # Ensure that diagonal has large enough values. If we generate a
    # self adjoint PD matrix, then the diagonal will be dominant guaranteeing
    # positive definiteness.
    diag = linear_operator_test_util.random_sign_uniform(
        shape[:-1], minval=4., maxval=6., dtype=dtype)
    # Off-diagonals are generated at full length. One padding entry of each
    # is ignored, depending on the alignment used below.
    subdiag = linear_operator_test_util.random_sign_uniform(
        shape[:-1], minval=1., maxval=2., dtype=dtype)
    if ensure_self_adjoint_and_pd:
      # Abs on complex64 will result in a float32, so we cast back up.
      diag = math_ops.cast(math_ops.abs(diag), dtype=dtype)
      # The first element of subdiag is ignored. Rolling the conjugated
      # subdiag left by one gives superdiag[i] == conj(subdiag[i + 1]). The
      # wrapped-around last element of superdiag is a dummy padding entry.
      superdiag = math_ops.conj(subdiag)
      superdiag = manip_ops.roll(superdiag, shift=-1, axis=-1)
    else:
      superdiag = linear_operator_test_util.random_sign_uniform(
          shape[:-1], minval=1., maxval=2., dtype=dtype)

    # Compact layout: [superdiag, diag, subdiag] stacked along axis -2.
    matrix_diagonals = array_ops.stack(
        [superdiag, diag, subdiag], axis=-2)
    # With align='LEFT_RIGHT', superdiagonals are left-aligned (last entry is
    # padding) and subdiagonals are right-aligned (first entry is padding).
    matrix = gen_array_ops.matrix_diag_v3(
        matrix_diagonals,
        k=(-1, 1),
        num_rows=-1,
        num_cols=-1,
        align='LEFT_RIGHT',
        padding_value=0.)

    if diagonals_format == 'sequence':
      lin_op_diagonals = [superdiag, diag, subdiag]
    elif diagonals_format == 'compact':
      # Identical to the stacking above; reuse it instead of stacking again.
      lin_op_diagonals = matrix_diagonals
    else:  # 'matrix' (guaranteed by the validation at the top).
      lin_op_diagonals = matrix

    if use_placeholder:
      if diagonals_format == 'sequence':
        lin_op_diagonals = [array_ops.placeholder_with_default(
            d, shape=None) for d in lin_op_diagonals]
      else:
        lin_op_diagonals = array_ops.placeholder_with_default(
            lin_op_diagonals, shape=None)

    operator = linalg_lib.LinearOperatorTridiag(
        diagonals=lin_op_diagonals,
        diagonals_format=diagonals_format,
        is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
        is_positive_definite=True if ensure_self_adjoint_and_pd else None)
    return operator, matrix
Пример #4
0
 def test_tape_safe(self):
     """Checks that a 'matrix'-format LinearOperatorTridiag is tape safe.

     The operator is built directly from a dense tridiagonal Variable.
     """
     tridiagonal = variables_module.Variable(
         [[3., 2., 0.],
          [1., 6., 4.],
          [0., 2, 2]])
     self.check_tape_safe(
         linalg_lib.LinearOperatorTridiag(
             tridiagonal, diagonals_format='matrix'))
Пример #5
0
 def test_tape_safe(self):
     """Checks that a 'compact'-format LinearOperatorTridiag is tape safe.

     All three diagonals live in a single Variable, one diagonal per row.
     """
     stacked = variables_module.Variable(
         [[3., 6., 2.],
          [2., 4., 2.],
          [5., 1., 2.]])
     self.check_tape_safe(
         linalg_lib.LinearOperatorTridiag(
             stacked, diagonals_format='compact'))