Esempio n. 1
0
def _QrGrad(op, dq, dr):
  """Backpropagates through a QR decomposition.

  Args:
    op: The forward op; its outputs are unpacked as `(q, r)`.
    dq: Incoming gradient with respect to `q`.
    dr: Incoming gradient with respect to `r`.

  Returns:
    The sum of the two gradient contributions computed from `dq` and `dr`.

  Raises:
    NotImplementedError: If `q` is complex, if the rank or either of the last
      two dimensions of `r` is not statically known, or if the last two
      dimensions of `r` differ.
  """
  q, r = op.outputs

  def _right_solve_adjoint(lhs, upper):
    """Same as matmul(lhs, adjoint(matrix_inverse(upper))), upper being upper-tri."""
    solved = linalg_ops.matrix_triangular_solve(
        upper, _linalg.adjoint(lhs), lower=False, adjoint=False)
    return _linalg.adjoint(solved)

  if q.dtype.is_complex:
    raise NotImplementedError("QrGrad not implemented for dtype: %s" % q.dtype)

  # Everything below needs a fully static inner matrix shape for `r`.
  r_static_shape = r.shape
  if r_static_shape.ndims is None:
    raise NotImplementedError("QrGrad not implemented with dynamic shapes.")
  r_dims = r_static_shape.as_list()
  if r_dims[-2] is None or r_dims[-1] is None:
    raise NotImplementedError("QrGrad not implemented with dynamic shapes.")
  if r_static_shape.dims[-2].value != r_static_shape.dims[-1].value:
    raise NotImplementedError("QrGrad not implemented when ncols > nrows "
                              "or full_matrices is true and ncols != nrows.")

  q_h_dq = math_ops.matmul(q, dq, adjoint_a=True)
  q_part = q_h_dq - _linalg.adjoint(q_h_dq)
  r_dr_h = math_ops.matmul(r, dr, adjoint_b=True)
  r_part = r_dr_h - _linalg.adjoint(r_dr_h)
  # Keep only the lower triangle (diagonal included) of the combined term.
  lower = array_ops.matrix_band_part(q_part + r_part, -1, 0)

  first_term = math_ops.matmul(q, dr + _right_solve_adjoint(lower, r))
  second_term = _right_solve_adjoint(dq - math_ops.matmul(q, q_h_dq), r)
  return first_term + second_term
 def _test_solve(self, with_batch):
   """Compares `operator.solve` with a dense solve over every test setting.

   Args:
     with_batch: Forwarded to `self._make_rhs` when building the right-hand
       side.
   """
   for use_placeholder in self._use_placeholder_options:
     for build_info in self._operator_build_infos:
       for dtype in self._dtypes_to_test:
         for adjoint in self._adjoint_options:
           for adjoint_arg in self._adjoint_arg_options:
             with self.test_session(graph=ops.Graph()) as sess:
               sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
               operator, dense, feeds = self._operator_and_mat_and_feed_dict(
                   build_info, dtype, use_placeholder=use_placeholder)
               rhs = self._make_rhs(
                   operator, adjoint=adjoint, with_batch=with_batch)
               # With adjoint_arg the operator receives rhs^H, so that
               # A X = (rhs^H)^H = rhs is what actually gets solved.
               solve_input = linalg.adjoint(rhs) if adjoint_arg else rhs
               actual = operator.solve(
                   solve_input, adjoint=adjoint, adjoint_arg=adjoint_arg)
               expected = linear_operator_util.matrix_solve_with_broadcast(
                   dense, rhs, adjoint=adjoint)
               if not use_placeholder:
                 self.assertAllEqual(actual.get_shape(),
                                     expected.get_shape())
               actual_v, expected_v = sess.run(
                   [actual, expected], feed_dict=feeds)
               self.assertAC(actual_v, expected_v)
def _test_matmul_base(
    self,
    use_placeholder,
    shapes_info,
    dtype,
    adjoint,
    adjoint_arg,
    with_batch):
  """Checks `operator.matmul` against a dense matmul for one configuration.

  Args:
    self: The test case supplying the session, builders and assertions.
    use_placeholder: Forwarded to `self.operator_and_matrix`; also disables
      the static-shape comparison when truthy.
    shapes_info: Build description; its `shape` length decides the skip.
    dtype: Forwarded to `self.operator_and_matrix`.
    adjoint: Whether the operator is applied adjointed.
    adjoint_arg: Whether the argument is handed over adjointed.
    with_batch: Forwarded to `self.make_x`.
  """
  # Without batch dims on the argument AND on the operator there is nothing
  # new to check: that combination is already covered by with_batch=True.
  if not (with_batch or len(shapes_info.shape) > 2):
    return
  with self.session(graph=ops.Graph()) as sess:
    sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
    operator, dense = self.operator_and_matrix(
        shapes_info, dtype, use_placeholder=use_placeholder)
    x = self.make_x(operator, adjoint=adjoint, with_batch=with_batch)
    if not adjoint_arg:
      actual = operator.matmul(x, adjoint=adjoint)
    else:
      # Hand over x^H together with adjoint_arg, so that A X^H^H = A X.
      actual = operator.matmul(
          linalg.adjoint(x), adjoint=adjoint, adjoint_arg=adjoint_arg)
    expected = math_ops.matmul(dense, x, adjoint_a=adjoint)
    if not use_placeholder:
      self.assertAllEqual(actual.get_shape(), expected.get_shape())
    actual_v, expected_v = sess.run([actual, expected])
    self.assertAC(actual_v, expected_v)
 def _matmul(self, x, adjoint=False, adjoint_arg=False):
   """Returns `x` (adjointed on request) after batch-shape broadcasting.

   Args:
     x: Tensor to multiply by.
     adjoint: Not read; the matrix is self-adjoint, so it has no effect.
     adjoint_arg: If truthy, `x` is replaced by its adjoint first.
   """
   if adjoint_arg:
     x = linalg.adjoint(x)
   if self._assert_proper_shapes:
     # Tie the dimension check to `x` so it runs whenever `x` is consumed.
     shape_check = linear_operator_util.assert_compatible_matrix_dimensions(
         self, x)
     x = control_flow_ops.with_dependencies([shape_check], x)
   return self._possibly_broadcast_batch_shape(x)
 def _test_matmul(self, with_batch):
   """Compares `operator.matmul` with a dense matmul over every test setting.

   Args:
     with_batch: Forwarded to `self._make_x` when building the argument.
   """
   for use_placeholder in self._use_placeholder_options:
     for build_info in self._operator_build_infos:
       for dtype in self._dtypes_to_test:
         for adjoint in self._adjoint_options:
           for adjoint_arg in self._adjoint_arg_options:
             with self.test_session(graph=ops.Graph()) as sess:
               sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
               operator, dense, feeds = self._operator_and_mat_and_feed_dict(
                   build_info, dtype, use_placeholder=use_placeholder)
               x = self._make_x(
                   operator, adjoint=adjoint, with_batch=with_batch)
               if not adjoint_arg:
                 actual = operator.matmul(x, adjoint=adjoint)
               else:
                 # Hand over x^H with adjoint_arg, so that A X^H^H = A X.
                 actual = operator.matmul(
                     linalg.adjoint(x),
                     adjoint=adjoint,
                     adjoint_arg=adjoint_arg)
               expected = linear_operator_util.matmul_with_broadcast(
                   dense, x, adjoint_a=adjoint)
               if not use_placeholder:
                 self.assertAllEqual(actual.get_shape(),
                                     expected.get_shape())
               actual_v, expected_v = sess.run(
                   [actual, expected], feed_dict=feeds)
               self.assertAC(actual_v, expected_v)
 def _test_solve(self, with_batch):
   """Compares `operator.solve` with a dense solve over every test setting.

   Args:
     with_batch: Forwarded to `self._make_rhs`; when falsy, build infos whose
       `shape` has at most two entries are skipped.
   """
   for use_placeholder in self._use_placeholder_options:
     for build_info in self._operator_build_infos:
       # An operator without batch dims combined with a batch-free rhs is
       # already exercised by the with_batch=True run, so skip it here.
       if not (with_batch or len(build_info.shape) > 2):
         continue
       for dtype in self._dtypes_to_test:
         for adjoint in self._adjoint_options:
           for adjoint_arg in self._adjoint_arg_options:
             with self.test_session(graph=ops.Graph()) as sess:
               sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
               operator, dense, feeds = self._operator_and_mat_and_feed_dict(
                   build_info, dtype, use_placeholder=use_placeholder)
               rhs = self._make_rhs(
                   operator, adjoint=adjoint, with_batch=with_batch)
               # With adjoint_arg the operator receives rhs^H, so that
               # A X = (rhs^H)^H = rhs is what actually gets solved.
               solve_input = linalg.adjoint(rhs) if adjoint_arg else rhs
               actual = operator.solve(
                   solve_input, adjoint=adjoint, adjoint_arg=adjoint_arg)
               expected = linear_operator_util.matrix_solve_with_broadcast(
                   dense, rhs, adjoint=adjoint)
               if not use_placeholder:
                 self.assertAllEqual(actual.get_shape(),
                                     expected.get_shape())
               actual_v, expected_v = sess.run(
                   [actual, expected], feed_dict=feeds)
               self.assertAC(actual_v, expected_v)
 def _test_matmul(self, with_batch):
   """Compares `operator.matmul` with a dense matmul over every test setting.

   Args:
     with_batch: Forwarded to `self._make_x`; when falsy, build infos whose
       `shape` has at most two entries are skipped.
   """
   for use_placeholder in self._use_placeholder_options:
     for build_info in self._operator_build_infos:
       # An operator without batch dims combined with a batch-free argument
       # is already exercised by the with_batch=True run, so skip it here.
       if not (with_batch or len(build_info.shape) > 2):
         continue
       for dtype in self._dtypes_to_test:
         for adjoint in self._adjoint_options:
           for adjoint_arg in self._adjoint_arg_options:
             with self.session(graph=ops.Graph()) as sess:
               sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
               operator, dense = self._operator_and_matrix(
                   build_info, dtype, use_placeholder=use_placeholder)
               x = self._make_x(
                   operator, adjoint=adjoint, with_batch=with_batch)
               if not adjoint_arg:
                 actual = operator.matmul(x, adjoint=adjoint)
               else:
                 # Hand over x^H with adjoint_arg, so that A X^H^H = A X.
                 actual = operator.matmul(
                     linalg.adjoint(x),
                     adjoint=adjoint,
                     adjoint_arg=adjoint_arg)
               expected = linear_operator_util.matmul_with_broadcast(
                   dense, x, adjoint_a=adjoint)
               if not use_placeholder:
                 self.assertAllEqual(actual.get_shape(),
                                     expected.get_shape())
               actual_v, expected_v = sess.run([actual, expected])
               self.assertAC(actual_v, expected_v)
 def test_solve(self):
   """Compares `operator.solve` with `matrix_solve` for every shape and dtype."""
   self._skip_if_tests_to_skip_contains("solve")
   for use_placeholder in self._use_placeholder_options:
     for shape in self._shapes_to_test:
       for dtype in self._dtypes_to_test:
         for adjoint in self._adjoint_options:
           for adjoint_arg in self._adjoint_arg_options:
             with self.test_session(graph=ops.Graph()) as sess:
               sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
               operator, dense, feeds = self._operator_and_mat_and_feed_dict(
                   shape, dtype, use_placeholder=use_placeholder)
               rhs = self._make_rhs(operator, adjoint=adjoint)
               # With adjoint_arg the operator receives rhs^H, so that
               # A X = (rhs^H)^H = rhs is what actually gets solved.
               solve_input = linalg.adjoint(rhs) if adjoint_arg else rhs
               actual = operator.solve(
                   solve_input, adjoint=adjoint, adjoint_arg=adjoint_arg)
               expected = linalg_ops.matrix_solve(dense, rhs, adjoint=adjoint)
               if not use_placeholder:
                 self.assertAllEqual(actual.get_shape(),
                                     expected.get_shape())
               actual_v, expected_v = sess.run(
                   [actual, expected], feed_dict=feeds)
               self.assertAC(actual_v, expected_v)
 def test_matmul(self):
   """Compares `operator.matmul` with a dense matmul for every shape and dtype."""
   self._skip_if_tests_to_skip_contains("matmul")
   for use_placeholder in self._use_placeholder_options:
     for shape in self._shapes_to_test:
       for dtype in self._dtypes_to_test:
         for adjoint in self._adjoint_options:
           for adjoint_arg in self._adjoint_arg_options:
             with self.test_session(graph=ops.Graph()) as sess:
               sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
               operator, dense, feeds = self._operator_and_mat_and_feed_dict(
                   shape, dtype, use_placeholder=use_placeholder)
               x = self._make_x(operator, adjoint=adjoint)
               if not adjoint_arg:
                 actual = operator.matmul(x, adjoint=adjoint)
               else:
                 # Hand over x^H with adjoint_arg, so that A X^H^H = A X.
                 actual = operator.matmul(
                     linalg.adjoint(x),
                     adjoint=adjoint,
                     adjoint_arg=adjoint_arg)
               expected = math_ops.matmul(dense, x, adjoint_a=adjoint)
               if not use_placeholder:
                 self.assertAllEqual(actual.get_shape(),
                                     expected.get_shape())
               actual_v, expected_v = sess.run(
                   [actual, expected], feed_dict=feeds)
               self.assertAC(actual_v, expected_v)
  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Returns zeros shaped like the product of this operator with `x`.

    Only the shape and dtype of `x` are used to build the result.

    Args:
      x: Tensor to (notionally) multiply by.
      adjoint: If truthy, the non-square output takes its row count from
        `self._num_columns` instead of `self._num_rows`. It has no effect in
        the square case.
      adjoint_arg: If truthy, `x` is treated as its adjoint, i.e. its last two
        dimensions are read in swapped order.

    Returns:
      A zeros tensor of dtype `x.dtype`, passed through
      `self._possibly_broadcast_batch_shape`.
    """
    if self._assert_proper_shapes:
      # Run the compatibility check on the adjointed view, but keep `x` itself
      # un-adjointed: the shape logic below applies `adjoint_arg` on its own,
      # so rebinding `x` to its adjoint here would swap the trailing
      # dimensions twice and yield a wrongly shaped result.
      x_to_check = linalg.adjoint(x) if adjoint_arg else x
      aps = linear_operator_util.assert_compatible_matrix_dimensions(
          self, x_to_check)
      x = control_flow_ops.with_dependencies([aps], x)

    x_shape = array_ops.shape(x)
    if self.is_square:
      # Note that adjoint has no effect since this matrix is self-adjoint.
      if adjoint_arg:
        output_shape = array_ops.concat(
            [x_shape[:-2], [x_shape[-1], x_shape[-2]]], axis=0)
      else:
        output_shape = x_shape
    else:
      n = self._num_columns if adjoint else self._num_rows
      m = x_shape[-2] if adjoint_arg else x_shape[-1]
      output_shape = array_ops.concat([x_shape[:-2], [n, m]], axis=0)

    zeros = array_ops.zeros(shape=output_shape, dtype=x.dtype)
    return self._possibly_broadcast_batch_shape(zeros)
Esempio n. 11
0
 def _assert_self_adjoint(self):
   """Returns an assertion that the dense matrix equals its own adjoint.

   Logs a warning because this path goes through the dense matrix.
   """
   dense_matrix = self._get_cached_dense_matrix()
   logging.warn(
       "Using (possibly slow) default implementation of assert_self_adjoint."
       "  Requires conversion to a dense matrix.")
   dense_adjoint = linalg.adjoint(dense_matrix)
   return check_ops.assert_equal(
       dense_matrix,
       dense_adjoint,
       message="Matrix was not equal to its adjoint.")
 def _matmul(self, x, adjoint=False, adjoint_arg=False):
   """Multiplies `x` elementwise by the multiplier matrix.

   Args:
     x: Tensor to multiply.
     adjoint: If truthy, the conjugate multiplier is used.
     adjoint_arg: If truthy, `x` is replaced by its adjoint first.
   """
   if adjoint_arg:
     x = linalg.adjoint(x)
   multiplier = (
       self._multiplier_matrix_conj if adjoint else self._multiplier_matrix)
   if self._assert_proper_shapes:
     # Tie the dimension check to `x` so it runs whenever `x` is consumed.
     shape_check = linear_operator_util.assert_compatible_matrix_dimensions(
         self, x)
     x = control_flow_ops.with_dependencies([shape_check], x)
   return x * multiplier
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
   """Divides `rhs` elementwise by the multiplier matrix.

   Args:
     rhs: Right-hand side tensor.
     adjoint: If truthy, the conjugate multiplier is used.
     adjoint_arg: If truthy, `rhs` is replaced by its adjoint first.
   """
   if adjoint_arg:
     rhs = linalg.adjoint(rhs)
   divisor = (
       self._multiplier_matrix_conj if adjoint else self._multiplier_matrix)
   if self._assert_proper_shapes:
     # Tie the dimension check to `rhs` so it runs whenever `rhs` is consumed.
     shape_check = linear_operator_util.assert_compatible_matrix_dimensions(
         self, rhs)
     rhs = control_flow_ops.with_dependencies([shape_check], rhs)
   return rhs / divisor
  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solves by dividing the transformed `rhs` by the spectrum.

    Args:
      rhs: Right-hand side tensor.
      adjoint: If truthy, the conjugate spectrum is used.
      adjoint_arg: If truthy, `rhs` is replaced by its adjoint first.

    Returns:
      The solution, cast to `self.dtype`.
    """
    if adjoint_arg:
      rhs = linalg.adjoint(rhs)
    if adjoint:
      spectrum = self._conj_spectrum
    else:
      spectrum = self._spectrum_complex

    rhs, spectrum = self._broadcast_batch_dims(rhs, spectrum)

    # Forward transform, divide by the spectrum, then transform back.
    rhs_transformed = self._fft(self._vectorize_then_blockify(rhs))
    solution_blocks = self._ifft(rhs_transformed / spectrum)
    return math_ops.cast(
        self._unblockify_then_matricize(solution_blocks), self.dtype)
 def test_adjoint(self):
   """Checks `operator.adjoint()` and `operator.H` against the dense adjoint.

   NOTE(review): `shapes_info`, `dtype` and `use_placeholder` are not
   parameters of this function; presumably they come from an enclosing scope
   that is not visible here -- confirm.
   """
   with self.test_session(graph=ops.Graph()) as sess:
     sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
     operator, dense = self.operator_and_matrix(
         shapes_info, dtype, use_placeholder=use_placeholder)
     via_method = operator.adjoint().to_dense()
     via_property = operator.H.to_dense()
     expected = linalg.adjoint(dense)
     via_method_v, via_property_v, expected_v = sess.run(
         [via_method, via_property, expected])
     self.assertAC(expected_v, via_method_v)
     self.assertAC(expected_v, via_property_v)
Esempio n. 16
0
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
   """Fallback solve that goes through a dense matrix or Cholesky factor.

   Args:
     rhs: Right-hand side tensor.
     adjoint: Forwarded to `matrix_solve`; not used on the Cholesky path.
     adjoint_arg: If truthy, `rhs` is replaced by its adjoint first.

   Raises:
     NotImplementedError: If `self.is_square` is exactly `False`.
   """
   # Identity check on purpose: only an explicit False is rejected.
   if self.is_square is False:
     raise NotImplementedError(
         "Solve is not yet implemented for non-square operators.")
   logging.warn(
       "Using (possibly slow) default implementation of solve."
       "  Requires conversion to a dense matrix and O(N^3) operations.")
   if adjoint_arg:
     rhs = linalg.adjoint(rhs)
   if not self._can_use_cholesky():
     return linalg_ops.matrix_solve(
         self._get_cached_dense_matrix(), rhs, adjoint=adjoint)
   return linalg_ops.cholesky_solve(self._get_cached_chol(), rhs)
 def test_adjoint(self):
   """Checks `operator.adjoint()` and `operator.H` against the dense adjoint."""
   self._skip_if_tests_to_skip_contains("adjoint")
   for use_placeholder in self._use_placeholder_options:
     for build_info in self._operator_build_infos:
       for dtype in self._dtypes_to_test:
         with self.test_session(graph=ops.Graph()) as sess:
           sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
           operator, dense = self._operator_and_matrix(
               build_info, dtype, use_placeholder=use_placeholder)
           via_method = operator.adjoint().to_dense()
           via_property = operator.H.to_dense()
           expected = linalg.adjoint(dense)
           via_method_v, via_property_v, expected_v = sess.run(
               [via_method, via_property, expected])
           self.assertAC(expected_v, via_method_v)
           self.assertAC(expected_v, via_property_v)
  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Multiplies by scaling the transformed `x` with the spectrum.

    Args:
      x: Tensor to multiply.
      adjoint: If truthy, the conjugate spectrum is used.
      adjoint_arg: If truthy, `x` is replaced by its adjoint first.

    Returns:
      The product, cast to `self.dtype`.
    """
    if adjoint_arg:
      x = linalg.adjoint(x)
    # Let F be the DFT matrix. Its inverse F^{-1} equals its Hermitian
    # transpose F^{H}, which is the IDFT matrix. Hence
    #   matmul(x)               = F^{H} diag(spectrum) F x,
    #   matmul(x, adjoint=True) = F^{H} diag(conj(spectrum)) F x.
    if adjoint:
      spectrum = self._conj_spectrum
    else:
      spectrum = self._spectrum_complex

    x, spectrum = self._broadcast_batch_dims(x, spectrum)

    x_transformed = self._fft(self._vectorize_then_blockify(x))
    scaled_blocks = self._ifft(spectrum * x_transformed)
    return math_ops.cast(
        self._unblockify_then_matricize(scaled_blocks), self.dtype)
  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Reflects `x` about the hyperplane orthogonal to `self.reflection_axis`.

    Args:
      x: Tensor to reflect.
      adjoint: Not read; a reflection is its own adjoint.
      adjoint_arg: If truthy, `x` is replaced by its adjoint first.
    """
    # Derivation, for an axis vector v:
    #   * projecting x onto v gives v * dot(v, x) / dot(v, v);
    #   * reflecting that projection flips its sign;
    #   * the remainder x - v * dot(v, x) / dot(v, v) lies in the hyperplane
    #     and is unchanged by the reflection.
    # Adding the last two pieces yields x - 2 * v * dot(v, x) / dot(v, v).
    #
    # A reflection belongs to O(n) (real case) or U(n) (complex case) and is
    # therefore its own adjoint.
    if adjoint_arg:
      x = linalg.adjoint(x)
    unit_axis = self.reflection_axis / linalg.norm(
        self.reflection_axis, axis=-1, keepdims=True)
    # Append a trailing axis so the unit vector can act as a column matrix.
    axis_column = unit_axis[..., array_ops.newaxis]
    axis_dot_x = math_ops.matmul(axis_column, x, adjoint_a=True)

    return x - 2 * axis_column * axis_dot_x
Esempio n. 20
0
def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
  """Backpropagates through SelfAdjointEigV2.

  Args:
    op: The forward op; `op.outputs[0]` holds the eigenvalues and, when the
      `compute_v` attribute is set, `op.outputs[1]` holds the eigenvectors.
    grad_e: Incoming gradient with respect to the eigenvalues.
    grad_v: Incoming gradient with respect to the eigenvectors; only used
      when `compute_v` is set.

  Returns:
    The gradient with respect to the op's input, restricted to the lower
    triangle.
  """
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # With a = op.inputs[0], the decomposition satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      # Build f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # The f term makes the gradient infinite (NaN in practice) whenever
      # eigenvalues repeat. That is expected: for a k-fold degenerate
      # eigenvalue the eigenvectors are only determined up to an arbitrary
      # rotation within a k-dimensional subspace.
      eigenvalue_gaps = (
          array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1))
      f = array_ops.matrix_set_diag(
          math_ops.reciprocal(eigenvalue_gaps), array_ops.zeros_like(e))
      inner = (
          array_ops.matrix_diag(grad_e) +
          f * math_ops.matmul(v, grad_v, adjoint_a=True))
      grad_a = math_ops.matmul(
          v, math_ops.matmul(inner, v, adjoint_b=True))
    else:
      # Eigenvectors were not produced by the forward op; recompute them.
      _, v = linalg_ops.self_adjoint_eig(op.inputs[0])
      scaled = math_ops.matmul(
          array_ops.matrix_diag(grad_e), v, adjoint_b=True)
      grad_a = math_ops.matmul(v, scaled)
    # The forward op reads only the lower triangle of its input, so
    # symmetrize, keep the lower triangle, and halve the doubled diagonal.
    lower = array_ops.matrix_band_part(
        grad_a + _linalg.adjoint(grad_a), -1, 0)
    return array_ops.matrix_set_diag(
        lower, 0.5 * array_ops.matrix_diag_part(lower))
Esempio n. 21
0
def _CholeskyGrad(op, grad):
  """Backpropagates through a Cholesky factorization.

  Args:
    op: The forward op; `op.outputs[0]` is the factor `l`.
    grad: Incoming gradient with respect to `l`.

  Returns:
    The gradient l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1},
    averaged with its own adjoint.
  """
  chol = op.outputs[0]
  chol_shape_rows = array_ops.shape(chol)[-1]
  chol_batch_shape = array_ops.shape(chol)[:-2]
  # Invert the factor by solving against a batched identity.
  identity = linalg_ops.eye(
      chol_shape_rows, batch_shape=chol_batch_shape, dtype=chol.dtype)
  chol_inverse = linalg_ops.matrix_triangular_solve(chol, identity)

  # (l^H @ grad) with the diagonal halved and the strict upper part dropped.
  inner = math_ops.matmul(chol, grad, adjoint_a=True)
  inner = array_ops.matrix_set_diag(
      inner, 0.5 * array_ops.matrix_diag_part(inner))
  inner = array_ops.matrix_band_part(inner, -1, 0)

  one_sided = math_ops.matmul(
      math_ops.matmul(chol_inverse, inner, adjoint_a=True), chol_inverse)

  symmetric = one_sided + _linalg.adjoint(one_sided)
  return symmetric * 0.5
Esempio n. 22
0
def _matmul(  # pylint:disable=missing-docstring
    a,
    b,
    transpose_a=False,
    transpose_b=False,
    adjoint_a=False,
    adjoint_b=False,
    a_is_sparse=False,
    b_is_sparse=False,
    name=None):
  if transpose_a or transpose_b:
    raise ValueError("Transposing not supported at this time.")
  if a_is_sparse or b_is_sparse:
    raise ValueError("Sparse methods not supported at this time.")
  if not isinstance(a, LinearOperator):
    # We use the identity (B^HA^H)^H =  AB
    adjoint_matmul = b.matmul(
        a,
        adjoint=(not adjoint_b),
        adjoint_arg=(not adjoint_a),
        name=name)
    return linalg.adjoint(adjoint_matmul)
  return a.matmul(
      b, adjoint=adjoint_a, adjoint_arg=adjoint_b, name=name)
def _test_solve_base(
    self,
    use_placeholder,
    shapes_info,
    dtype,
    adjoint,
    adjoint_arg,
    with_batch):
  """Checks `operator.solve` against a dense solve for one configuration.

  Args:
    self: The test case supplying the session, builders and assertions.
    use_placeholder: Forwarded to `self.operator_and_matrix`; also disables
      the static-shape comparison when truthy.
    shapes_info: Build description; its `shape` length decides the skip.
    dtype: Forwarded to `self.operator_and_matrix`.
    adjoint: Whether the operator is applied adjointed.
    adjoint_arg: Whether the right-hand side is handed over adjointed.
    with_batch: Forwarded to `self.make_rhs`.
  """
  # Without batch dims on the rhs AND on the operator there is nothing new
  # to check: that combination is already covered by with_batch=True.
  if not (with_batch or len(shapes_info.shape) > 2):
    return
  with self.session(graph=ops.Graph()) as sess:
    sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
    operator, dense = self.operator_and_matrix(
        shapes_info, dtype, use_placeholder=use_placeholder)
    rhs = self.make_rhs(operator, adjoint=adjoint, with_batch=with_batch)
    # With adjoint_arg the operator receives rhs^H, so that
    # A X = (rhs^H)^H = rhs is what actually gets solved.
    solve_input = linalg.adjoint(rhs) if adjoint_arg else rhs
    actual = operator.solve(
        solve_input, adjoint=adjoint, adjoint_arg=adjoint_arg)
    expected = linear_operator_util.matrix_solve_with_broadcast(
        dense, rhs, adjoint=adjoint)
    if not use_placeholder:
      self.assertAllEqual(actual.get_shape(), expected.get_shape())
    actual_v, expected_v = sess.run([actual, expected])
    self.assertAC(actual_v, expected_v)
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
   """Solves against the lower-triangular `self._tril`, with broadcasting.

   Args:
     rhs: Right-hand side tensor.
     adjoint: Forwarded to the triangular solve.
     adjoint_arg: If truthy, `rhs` is replaced by its adjoint first.
   """
   if adjoint_arg:
     rhs = linalg.adjoint(rhs)
   solution = linear_operator_util.matrix_triangular_solve_with_broadcast(
       self._tril, rhs, lower=True, adjoint=adjoint)
   return solution
Esempio n. 25
0
 def _TriangularSolve(x, r):
   """Same as matmul(x, adjoint(matrix_inverse(r))) when r is upper-tri."""
   # Solve r @ y = adjoint(x), then take the adjoint of y.
   solved = linalg_ops.matrix_triangular_solve(
       r, _linalg.adjoint(x), lower=False, adjoint=False)
   return _linalg.adjoint(solved)
Esempio n. 26
0
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
   """Solves against the lower-triangular matrix from `self._get_tril()`.

   Args:
     rhs: Right-hand side tensor.
     adjoint: Forwarded to the triangular solve.
     adjoint_arg: If truthy, `rhs` is replaced by its adjoint first.
   """
   if adjoint_arg:
     rhs = linalg.adjoint(rhs)
   solution = linalg.triangular_solve(
       self._get_tril(), rhs, lower=True, adjoint=adjoint)
   return solution
Esempio n. 27
0
def _reshape_for_efficiency(a,
                            b,
                            transpose_a=False,
                            transpose_b=False,
                            adjoint_a=False,
                            adjoint_b=False):
  """Maybe reshape a, b, and return an inverse map.  For matmul/solve.

  When `b` has more dimensions than `a`, the extra leading dimensions of `b`
  may be moved to the end and squashed into its last dimension, so that `a`
  does not have to be replicated across them.

  Args:
    a: Tensor; never reshaped here, but may be transposed/adjointed.
    b: Tensor; may be transposed/adjointed, permuted and reshaped.
    transpose_a: Whether the caller wants `a` transposed.
    transpose_b: Whether the caller wants `b` transposed.
    adjoint_a: Whether the caller wants `a` adjointed (checked before
      `transpose_a`).
    adjoint_b: Whether the caller wants `b` adjointed (checked before
      `transpose_b`).

  Returns:
    A tuple `(a, b, reshape_inv, still_need_to_transpose)` where `reshape_inv`
    is a callable to apply to the result computed from the returned `a` and
    `b`, and `still_need_to_transpose` is True when the requested
    transposes/adjoints have NOT been applied here and remain the caller's
    job. On every early return the inputs are handed back unchanged together
    with an identity map.
  """
  def identity(x):
    # Inverse map used whenever no reshaping took place.
    return x

  # At this point, we have not taken transpose/adjoint of a/b.
  still_need_to_transpose = True

  # Without static ranks we cannot tell how many extra dims b has.
  if a.shape.ndims is None or b.shape.ndims is None:
    return a, b, identity, still_need_to_transpose

  # This could be handled in the future, but seems less common.
  if a.shape.ndims >= b.shape.ndims:
    return a, b, identity, still_need_to_transpose

  # From now on, we might modify b, but will not modify a.

  # Suppose:
  #   a.shape =     C + [m, n]
  #   b.shape = S + C + [n, r]
  b_extra_ndims = b.shape.ndims - a.shape.ndims

  # b_extra_sh = S, b_main_sh = C + [n, r]
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # No reason to flip unless the extra dims of b are big enough.  Why?
  # Assume adjoint/transpose = False.  Then...
  # By not flipping, we have to replicate a to shape
  #   b_extra_sh + a.shape,
  # which could use extra memory.  But in all cases, the final output has shape
  #   b_extra_sh + a.shape[:-1] + [b.shape[-1]]
  # So we only end up creating a larger object if the end dim of b is smaller
  # than the end dim of a.  This often happens, e.g. if b was a vector that was
  # expanded to a matrix (by appending a singleton).

  # Since adjoint/transpose may not be False, we must make adjustments here.
  # The dim of b that holds the multiple equations.
  a_domain_sz_ = a.shape[-2 if adjoint_a or transpose_a else -1]
  b_eq_sz_ = b.shape[-2 if adjoint_b or transpose_b else -1]
  # Static product of b's extra leading dims, or None if any is unknown.
  b_extra_sz_ = (
      np.prod(b.shape[:b_extra_ndims].as_list())
      if b.shape[:b_extra_ndims].is_fully_defined() else None)
  # Skip the flip only when all three sizes are available and they show it
  # would not pay off; otherwise fall through and flip.
  if (a_domain_sz_ is not None and b_eq_sz_ is not None and
      b_extra_sz_ is not None):
    if b_extra_sz_ < 2 or a_domain_sz_ <= b_eq_sz_:
      return a, b, identity, still_need_to_transpose

  # At this point, we're flipping for sure!
  # Any transposes/adjoints will happen here explicitly, rather than in calling
  # code.  Why?  To avoid having to write separate complex code for each case.
  if adjoint_a:
    a = linalg.adjoint(a)
  elif transpose_a:
    a = linalg.transpose(a)
  if adjoint_b:
    b = linalg.adjoint(b)
  elif transpose_b:
    b = linalg.transpose(b)
  still_need_to_transpose = False

  # Recompute shapes, since the transpose/adjoint may have changed them.
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # Permutation to put the extra dims at the end.
  perm = (
      np.concatenate(
          (np.arange(b_extra_ndims, b.shape.ndims),
           np.arange(0, b_extra_ndims)), 0))
  b_extra_on_end = array_ops.transpose(b, perm=perm)

  # Now squash this end into one long dim.
  b_squashed_end = array_ops.reshape(
      b_extra_on_end, array_ops.concat((b_main_sh[:-1], [-1]), 0))

  def reshape_inv(y):
    """Undoes the squash and permutation on a result `y`."""
    # Expand the extra dims hanging off the end, "b_extra_sh".
    # Note we use y_sh[:-1] + [b_main_sh[-1]] rather than b_main_sh, because y
    # could have different batch dims than a and b, because of broadcasting.
    y_extra_shape = array_ops.concat(
        (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
    y_extra_on_end = array_ops.reshape(y, y_extra_shape)
    # argsort of a permutation yields its inverse permutation.
    inverse_perm = np.argsort(perm)
    return array_ops.transpose(y_extra_on_end, perm=inverse_perm)

  return a, b_squashed_end, reshape_inv, still_need_to_transpose
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
     """Solves against the lower-triangular `self._tril`.

     Args:
       rhs: Right-hand side tensor.
       adjoint: Forwarded to the triangular solve.
       adjoint_arg: If truthy, `rhs` is replaced by its adjoint first.
     """
     if adjoint_arg:
         rhs = linalg.adjoint(rhs)
     solution = linalg_ops.matrix_triangular_solve(
         self._tril, rhs, lower=True, adjoint=adjoint)
     return solution
Esempio n. 29
0
def _SvdGrad(op, grad_s, grad_u, grad_v):
  """Gradient for the singular value decomposition.

  Args:
    op: The forward SVD op. Its `compute_uv` and `full_matrices` attributes
      are read, `op.inputs[0]` is the decomposed matrix `a`, and (when
      `compute_uv` is set) `op.outputs` supplies `s`, `u`, `v`.
    grad_s: Incoming gradient with respect to the singular values.
    grad_u: Incoming gradient with respect to `u`.
    grad_v: Incoming gradient with respect to `v`.

  Returns:
    The gradient with respect to `a`, with its static shape set to the
    (merged) shape of `a`.

  Raises:
    NotImplementedError: If `compute_uv` is set and `a` is complex, if either
      inner matrix dimension of `a` is statically unknown, or if
      `full_matrices` is set and the inner dimensions differ by more than 1.
  """

  # The derivation for the compute_uv=False case, and most of
  # the derivation for the full_matrices=True case, are in
  # Giles' paper (see reference at top of file).  A derivation for
  # the full_matrices=False case is available at
  # https://j-towns.github.io/papers/svd-derivative.pdf
  a = op.inputs[0]
  a_shape = a.get_shape().with_rank_at_least(2)
  grad_s_mat = array_ops.matrix_diag(grad_s)

  if not op.get_attr("compute_uv"):
    # u and v were not produced by the forward op, so recompute the full
    # decomposition; the recomputed `s` is not used below.
    s, u, v = linalg_ops.svd(a, compute_uv=True)
    grad_a = math_ops.matmul(u, math_ops.matmul(grad_s_mat, v, adjoint_b=True))
    grad_a.set_shape(a_shape)
    return grad_a

  full_matrices = op.get_attr("full_matrices")

  # TODO(rmlarsen): Make this work with complex types.
  if a.dtype.is_complex:
    raise NotImplementedError(
        "SVD gradient is not implemented for complex types and "
        "compute_uv=True.")
  # Refine the static shape of `a` with what the incoming gradients know.
  grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
  grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
  m = a_shape.dims[-2].merge_with(grad_u_shape[-2])
  n = a_shape.dims[-1].merge_with(grad_v_shape[-2])
  batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
      grad_v_shape[:-2])
  a_shape = batch_shape.concatenate([m, n])

  # From here on m and n are plain Python values (or None), not Dimensions.
  m = a_shape.dims[-2].value
  n = a_shape.dims[-1].value
  # TODO(rmlarsen): Make this work with placeholders.
  if m is None or n is None:
    raise NotImplementedError(
        "SVD gradient has not been implemented for input with unknown "
        "inner matrix shape.")

  s = op.outputs[0]
  u = op.outputs[1]
  v = op.outputs[2]

  use_adjoint = False
  if m > n:
    # Compute the gradient for A^H = V * S^T * U^H, and (implicitly) take the
    # Hermitian transpose of the gradient at the end.
    use_adjoint = True
    m, n = n, m
    u, v = v, u
    grad_u, grad_v = grad_v, grad_u

  # After the optional swap above, m <= n holds for the rest of the function.
  with ops.control_dependencies([grad_s, grad_u, grad_v]):
    if full_matrices and abs(m - n) > 1:
      raise NotImplementedError(
          "svd gradient is not implemented for abs(m - n) > 1 "
          "when full_matrices is True")
    s_mat = array_ops.matrix_diag(s)
    s2 = math_ops.square(s)

    # NOTICE: Because of the term involving f, the gradient becomes
    # infinite (or NaN in practice) when singular values are not unique.
    # Mathematically this should not be surprising, since for (k-fold)
    # degenerate singular values, the corresponding singular vectors are
    # only defined up to a (k-dimensional) subspace. In practice, this can
    # lead to numerical instability when singular values are close but not
    # exactly equal.
    # f(i, j) = 1 / (s_j^2 - s_i^2) off the diagonal, 0 on the diagonal.
    f = array_ops.matrix_set_diag(
        math_ops.reciprocal(
            array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)),
        array_ops.zeros_like(s))
    s_inv_mat = array_ops.matrix_diag(math_ops.reciprocal(s))

    # Only the first m columns of v (and of its gradient) enter term1.
    v1 = v[..., :, :m]
    grad_v1 = grad_v[..., :, :m]

    u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
    v_gv = math_ops.matmul(v1, grad_v1, adjoint_a=True)

    f_u = f * u_gu
    f_v = f * v_gv

    term1_nouv = (
        grad_s_mat + math_ops.matmul(f_u + _linalg.adjoint(f_u), s_mat) +
        math_ops.matmul(s_mat, f_v + _linalg.adjoint(f_v)))

    term1 = math_ops.matmul(u, math_ops.matmul(term1_nouv, v1, adjoint_b=True))

    if m == n:
      grad_a_before_transpose = term1
    else:
      # Non-square case: add the extra term built from grad_v1 and v1.
      gv1t = array_ops.matrix_transpose(grad_v1)
      gv1t_v1 = math_ops.matmul(gv1t, v1)
      term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)

      if full_matrices:
        # Columns m..n of v exist only with full_matrices; account for them.
        v2 = v[..., :, m:n]
        grad_v2 = grad_v[..., :, m:n]

        v1t_gv2 = math_ops.matmul(v1, grad_v2, adjoint_a=True)
        term2_nous -= math_ops.matmul(v1t_gv2, v2, adjoint_b=True)

      u_s_inv = math_ops.matmul(u, s_inv_mat)
      term2 = math_ops.matmul(u_s_inv, term2_nous)

      grad_a_before_transpose = term1 + term2

    if use_adjoint:
      # Undo the swap performed for the m > n case.
      grad_a = array_ops.matrix_transpose(grad_a_before_transpose)
    else:
      grad_a = grad_a_before_transpose

    grad_a.set_shape(a_shape)
    return grad_a
Esempio n. 30
0
def _SvdGrad(op, grad_s, grad_u, grad_v):
    """Gradient for the singular value decomposition.

    The derivation for the compute_uv=False case, and most of the derivation
    for the full_matrices=True case, are in Giles' paper (see reference at top
    of file). The full_matrices=False case is derived in
    https://j-towns.github.io/papers/svd-derivative.pdf and the complex valued
    case in https://re-ra.xyz/misc/complexsvd.pdf or
    https://giggleliu.github.io/2019/04/02/einsumbp.html

    Args:
      op: The op being differentiated. `op.inputs[0]` is the matrix `a` and
        `op.outputs` holds `(s, u, v)`.
      grad_s: Incoming gradient with respect to `s`.
      grad_u: Incoming gradient with respect to `u`.
      grad_v: Incoming gradient with respect to `v`.

    Returns:
      The gradient with respect to `a`, with its static shape set to the
      merged static shape of `a`.

    Raises:
      NotImplementedError: If either inner matrix dimension is not statically
        known, or if `full_matrices` is set and the inner dimensions differ by
        more than one.
    """
    a = op.inputs[0]
    a_shape = a.get_shape().with_rank_at_least(2)
    grad_s = math_ops.cast(grad_s, a.dtype)
    grad_s_mat = array_ops.matrix_diag(grad_s)

    if not op.get_attr("compute_uv"):
        # The forward op did not return singular vectors, so recompute them;
        # the gradient is then u * diag(grad_s) * v^H.
        _, u, v = linalg_ops.svd(a, compute_uv=True)
        diag_vh = math_ops.matmul(grad_s_mat, v, adjoint_b=True)
        grad_a = math_ops.matmul(u, diag_vh)
        grad_a.set_shape(a_shape)
        return grad_a

    full_matrices = op.get_attr("full_matrices")

    # Refine the static shape of `a` with what is known statically about the
    # incoming gradients.
    u_grad_shape = grad_u.get_shape().with_rank_at_least(2)
    v_grad_shape = grad_v.get_shape().with_rank_at_least(2)
    rows_dim = a_shape.dims[-2].merge_with(u_grad_shape[-2])
    cols_dim = a_shape.dims[-1].merge_with(v_grad_shape[-2])
    batch_shape = a_shape[:-2].merge_with(u_grad_shape[:-2])
    batch_shape = batch_shape.merge_with(v_grad_shape[:-2])
    a_shape = batch_shape.concatenate([rows_dim, cols_dim])

    m = a_shape.dims[-2].value
    n = a_shape.dims[-1].value
    # TODO(rmlarsen): Make this work with placeholders.
    if m is None or n is None:
        raise NotImplementedError(
            "SVD gradient has not been implemented for input with unknown "
            "inner matrix shape.")

    s = math_ops.cast(op.outputs[0], a.dtype)
    u = op.outputs[1]
    v = op.outputs[2]

    # For tall matrices, compute the gradient for A^H = V * S^T * U^H instead
    # and (implicitly) take the Hermitian transpose of the result at the end.
    use_adjoint = m > n
    if use_adjoint:
        m, n = n, m
        u, v = v, u
        grad_u, grad_v = grad_v, grad_u

    with ops.control_dependencies([grad_s, grad_u, grad_v]):
        if full_matrices and abs(m - n) > 1:
            raise NotImplementedError(
                "svd gradient is not implemented for abs(m - n) > 1 "
                "when full_matrices is True")
        s_mat = array_ops.matrix_diag(s)
        s_squared = math_ops.square(s)

        # NOTICE: Because of the term involving f, the gradient becomes
        # infinite (or NaN in practice) when singular values are not unique.
        # Mathematically this should not be surprising, since for (k-fold)
        # degenerate singular values, the corresponding singular vectors are
        # only defined up to a (k-dimensional) subspace. In practice, this can
        # lead to numerical instability when singular values are close but
        # not exactly equal.

        s_shape = array_ops.shape(s)
        # Pairwise differences of squared singular values; their reciprocals
        # (with the diagonal forced to zero) form f.
        squared_gaps = (array_ops.expand_dims(s_squared, -2) -
                        array_ops.expand_dims(s_squared, -1))
        f = array_ops.matrix_set_diag(
            _SafeReciprocal(squared_gaps), array_ops.zeros_like(s))
        s_inv_mat = array_ops.matrix_diag(_SafeReciprocal(s))

        # Only the first m columns of v pair with the singular values.
        v1 = v[..., :, :m]
        grad_v1 = grad_v[..., :, :m]

        u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
        v_gv = math_ops.matmul(v1, grad_v1, adjoint_a=True)

        f_u = f * u_gu
        f_v = f * v_gv

        herm_u = f_u + _linalg.adjoint(f_u)
        term1_nouv = grad_s_mat + math_ops.matmul(herm_u, s_mat)
        herm_v = f_v + _linalg.adjoint(f_v)
        term1_nouv = term1_nouv + math_ops.matmul(s_mat, herm_v)

        term1 = math_ops.matmul(
            u, math_ops.matmul(term1_nouv, v1, adjoint_b=True))

        grad_a_before_transpose = term1
        if m != n:
            # Non-square case: add the contribution of grad_v that lies
            # outside the span of v1.
            gv1t = array_ops.matrix_transpose(grad_v1, conjugate=True)
            gv1t_v1 = math_ops.matmul(gv1t, v1)
            term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)

            if full_matrices:
                v2 = v[..., :, m:n]
                grad_v2 = grad_v[..., :, m:n]

                v1t_gv2 = math_ops.matmul(v1, grad_v2, adjoint_a=True)
                term2_nous -= math_ops.matmul(v1t_gv2, v2, adjoint_b=True)

            u_s_inv = math_ops.matmul(u, s_inv_mat)
            term2 = math_ops.matmul(u_s_inv, term2_nous)

            grad_a_before_transpose = term1 + term2

        if a.dtype.is_complex:
            # Extra term for complex inputs, built from the diagonal of v_gv.
            eye = _linalg.eye(s_shape[-1],
                              batch_shape=s_shape[:-1],
                              dtype=a.dtype)
            diag_v_gv = eye * v_gv
            term3_nouv = math_ops.matmul(
                s_inv_mat, _linalg.adjoint(diag_v_gv) - diag_v_gv)
            term3 = 0.5 * math_ops.matmul(
                u, math_ops.matmul(term3_nouv, v1, adjoint_b=True))

            grad_a_before_transpose += term3

        grad_a = (
            array_ops.matrix_transpose(grad_a_before_transpose, conjugate=True)
            if use_adjoint else grad_a_before_transpose)

        grad_a.set_shape(a_shape)
        return grad_a
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
   """Solve by dividing `rhs` elementwise by the multiplier matrix.

   Args:
     rhs: Right-hand side; replaced by `linalg.adjoint(rhs)` when
       `adjoint_arg` is true.
     adjoint: Forwarded as `conjugate` to `_make_multiplier_matrix`.
     adjoint_arg: Whether to take the adjoint of `rhs` before solving.

   Returns:
     `rhs / self._make_multiplier_matrix(conjugate=adjoint)`.
   """
   if adjoint_arg:
     rhs = linalg.adjoint(rhs)
   if self._assert_proper_shapes:
     # Tie the dimension-compatibility assertion to `rhs` as a dependency.
     shape_check = linear_operator_util.assert_compatible_matrix_dimensions(
         self, rhs)
     rhs = control_flow_ops.with_dependencies([shape_check], rhs)
   multiplier = self._make_multiplier_matrix(conjugate=adjoint)
   return rhs / multiplier
def _test_matmul_base(
    self,
    use_placeholder,
    shapes_info,
    dtype,
    adjoint,
    adjoint_arg,
    blockwise_arg,
    with_batch):
  """Compare `operator.matmul` with dense `math_ops.matmul` for one config.

  Builds an operator and its dense matrix via `self.operator_and_matrix`,
  multiplies both with the same `x` from `self.make_x`, and asserts the
  results agree via `self.assertAC`. When `blockwise_arg` is set and the
  operator has more than one sub-operator, the blockwise form of `x` is
  checked as well.

  Args:
    self: The test case providing the helpers and assertions used here.
    use_placeholder: Forwarded to `operator_and_matrix`; when true the static
      shape comparison is skipped.
    shapes_info: Shape description; `shapes_info.shape` is inspected and the
      object is forwarded to `operator_and_matrix`.
    dtype: Forwarded to `operator_and_matrix`.
    adjoint: Whether to multiply with the adjoint of the operator.
    adjoint_arg: Whether to pass `x` to the operator in adjointed form.
    blockwise_arg: Whether to additionally test a blockwise `x`.
    with_batch: Forwarded to `make_x`.
  """
  # If batch dimensions are omitted but the linear operator has no batch
  # dimensions anyway, skip: that case is already checked with
  # with_batch=True.
  if not with_batch and len(shapes_info.shape) <= 2:
    return
  with self.session(graph=ops.Graph()) as sess:
    sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
    operator, mat = self.operator_and_matrix(
        shapes_info, dtype, use_placeholder=use_placeholder)
    x = self.make_x(operator, adjoint=adjoint, with_batch=with_batch)
    if adjoint_arg:
      # Pass x^H with adjoint_arg set, so the product is A X^H^H = A X.
      op_matmul = operator.matmul(
          linalg.adjoint(x), adjoint=adjoint, adjoint_arg=adjoint_arg)
    else:
      op_matmul = operator.matmul(x, adjoint=adjoint)
    mat_matmul = math_ops.matmul(mat, x, adjoint_a=adjoint)
    if not use_placeholder:
      self.assertAllEqual(op_matmul.shape, mat_matmul.shape)

    # Everything is evaluated in one `sess.run` call to avoid re-sampling the
    # random `x` in graph mode. The blockwise result is fetched only when the
    # operator is blockwise.
    fetches = [op_matmul, mat_matmul]
    check_blocks = blockwise_arg and len(operator.operators) > 1
    if check_blocks:
      # pylint: disable=protected-access
      if adjoint:
        block_dimensions = operator._block_range_dimensions()
        block_dimensions_fn = operator._block_range_dimension_tensors
      else:
        block_dimensions = operator._block_domain_dimensions()
        block_dimensions_fn = operator._block_domain_dimension_tensors
      # pylint: enable=protected-access
      x_blocks = linear_operator_util.split_arg_into_blocks(
          block_dimensions, block_dimensions_fn, x, axis=-2)
      if adjoint_arg:
        x_blocks = [linalg.adjoint(y) for y in x_blocks]
      block_results = operator.matmul(
          x_blocks, adjoint=adjoint, adjoint_arg=adjoint_arg)

      self.assertEqual(len(block_results), len(operator.operators))
      block_results = linear_operator_util.broadcast_matrix_batch_dims(
          block_results)
      fetches.append(array_ops.concat(block_results, axis=-2))

    values = sess.run(fetches)
    op_matmul_v, mat_matmul_v = values[0], values[1]

    if check_blocks:
      # The operator applied to blockwise input must match the dense product.
      self.assertAC(values[2], mat_matmul_v)

    # The operator applied to a `Tensor` must match the dense product.
    self.assertAC(op_matmul_v, mat_matmul_v)
 def _matmul(self, x, adjoint=False, adjoint_arg=False):
   """Multiply `x` elementwise by the multiplier matrix.

   Args:
     x: Argument to multiply; replaced by `linalg.adjoint(x)` when
       `adjoint_arg` is true.
     adjoint: Forwarded as `conjugate` to `_make_multiplier_matrix`.
     adjoint_arg: Whether to take the adjoint of `x` before multiplying.

   Returns:
     `x * self._make_multiplier_matrix(conjugate=adjoint)`.
   """
   if adjoint_arg:
     x = linalg.adjoint(x)
   if self._assert_proper_shapes:
     # Tie the dimension-compatibility assertion to `x` as a dependency.
     shape_check = linear_operator_util.assert_compatible_matrix_dimensions(
         self, x)
     x = control_flow_ops.with_dependencies([shape_check], x)
   multiplier = self._make_multiplier_matrix(conjugate=adjoint)
   return x * multiplier
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
   """Solve against the diagonal by scaling `rhs` with its reciprocal.

   Args:
     rhs: Right-hand side; replaced by `linalg.adjoint(rhs)` when
       `adjoint_arg` is true.
     adjoint: When true, `self._diag` is complex-conjugated first.
     adjoint_arg: Whether to take the adjoint of `rhs` before solving.

   Returns:
     `rhs` multiplied by the reciprocal diagonal.
   """
   diag_term = self._diag
   if adjoint:
     diag_term = math_ops.conj(diag_term)
   if adjoint_arg:
     rhs = linalg.adjoint(rhs)
   # Append a trailing axis to the reciprocal diagonal so that it broadcasts
   # across the last dimension of `rhs`.
   inv_diag_mat = array_ops.expand_dims(1. / diag_term, -1)
   return rhs * inv_diag_mat
  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solve against the Kronecker product one factor at a time.

    Follows the same use of Roth's column lemma as `matmul`, with every
    per-factor `matmul` replaced by `solve`. This relies on the property
    inv(A x B) = inv(A) x inv(B). The shape comments below describe the
    adjoint=False, adjoint_arg=False case, but other combinations are
    handled too.

    Args:
      rhs: Right-hand side; replaced by `linalg.adjoint(rhs)` when
        `adjoint_arg` is true.
      adjoint: Whether to solve against the adjoint of each factor.
      adjoint_arg: Whether to take the adjoint of `rhs` before solving.

    Returns:
      The solution; its static shape is set when `rhs` has a fully defined
      static shape.
    """
    if adjoint_arg:
      rhs = linalg.adjoint(rhs)

    # Always add a batch dimension to enable broadcasting to work: add zeros
    # whose shape is all ones (one per operator batch dim, plus two).
    unit_batch = array_ops.ones_like(self.batch_shape_tensor())
    padding_shape = array_ops.concat([unit_batch, [1, 1]], 0)
    rhs += array_ops.zeros(padding_shape, dtype=rhs.dtype.base_dtype)

    # rhs has shape [B, R, C]: B is some number of batch dimensions, R the
    # number of rows and C the number of columns. Roth's column lemma works on
    # batches of column vectors, so move C to the front (keeping B adjacent to
    # R so that broadcasting against the operators' batch dims still works),
    # then prepend an accumulator axis A to get [A, C, B, R].
    output = _rotate_last_dim(rhs, rotate_right=True)[array_ops.newaxis, ...]

    # Throughout the loop, A is the accumulated leading dimension (1 at the
    # start) and V is the last dimension (R at the start).
    for factor in self.operators[:-1]:
      factor_dim = (factor.range_dimension_tensor() if adjoint
                    else factor.domain_dimension_tensor())

      # [A, C, B, V] -> [A, C, B, V / factor_dim, factor_dim]
      output = _unvec_by(output, factor_dim)

      # Compute (X A^-1^T) as (A^-1 X^T)^T: transpose, solve with this factor,
      # transpose back.
      output = array_ops.matrix_transpose(
          factor.solve(
              array_ops.matrix_transpose(output),
              adjoint=adjoint,
              adjoint_arg=False))

      # Fold the solved dimension into the accumulator axis:
      # -> [A * op.range_dimension, C, B, V / op.domain_dimension]
      output = _rotate_last_dim(
          _vec(_rotate_last_dim(output, rotate_right=False)),
          rotate_right=True)

    # After the loop:
    #   A = self.range_dimension / op[-1].range_dimension
    #   V = op[-1].domain_dimension
    # The last factor is applied with solvevec, giving
    # [A, C, B, op[-1].range_dimension].
    output = self.operators[-1].solvevec(output, adjoint=adjoint)

    # Rearrange to [B1, ... Bn, self.range_dimension, C].
    output = _rotate_last_dim(
        _vec(_rotate_last_dim(output, rotate_right=False)),
        rotate_right=False)

    if rhs.shape.is_fully_defined():
      num_columns = rhs.shape[-1]
      static_batch = common_shapes.broadcast_shape(
          rhs.shape[:-2], self.batch_shape)
      num_rows = self.domain_dimension if adjoint else self.range_dimension
      output.set_shape(static_batch.concatenate([num_rows, num_columns]))

    return output
  def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
    """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for details.

    Given the blockwise `n + 1`-by-`n + 1` linear operator:

    op = [[A_00     0  ...     0  ...    0],
          [A_10  A_11  ...     0  ...    0],
          ...
          [A_k0  A_k1  ...  A_kk  ...    0],
          ...
          [A_n0  A_n1  ...  A_nk  ... A_nn]]

    we find `x = op.solve(y)` by observing that

    `y_k = A_k0.matmul(x_0) + A_k1.matmul(x_1) + ... + A_kk.matmul(x_k)`

    and therefore

    `x_k = A_kk.solve(y_k -
                      A_k0.matmul(x_0) - ... - A_k(k-1).matmul(x_(k-1)))`

    where `x_k` and `y_k` are the `k`th blocks obtained by decomposing `x`
    and `y` along their appropriate axes.

    We first solve `x_0 = A_00.solve(y_0)`. Proceeding inductively, we solve
    for `x_k`, `k = 1..n`, given `x_0..x_(k-1)`.

    The adjoint case is solved similarly, beginning with
    `x_n = A_nn.solve(y_n, adjoint=True)` and proceeding backwards.

    Examples:

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve R > 0 linear systems for every member of the batch.
    RHS = ... # shape [..., M, R]

    X = operator.solve(RHS)
    # X[..., :, r] is the solution to the r'th linear system
    # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

    operator.matmul(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator and compatible shape,
        or a list of `Tensor`s. `Tensor`s are treated like [batch] matrices,
        meaning for every set of leading dimensions, the last two dimensions
        define a matrix. A blockwise `rhs` is never modified in place; its
        blocks are converted into a new list.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
        of this `LinearOperator`:  `A^H X = rhs`.
      adjoint_arg:  Python `bool`.  If `True`, solve `A X = rhs^H` where `rhs^H`
        is the hermitian transpose (transposition and complex conjugation).
      name:  A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`. If `rhs` was
      given blockwise, a list with one solution block per block of `rhs` is
      returned instead. If `rhs` is a `LinearOperator`, the result of
      `linear_operator_algebra.solve` is returned.

    Raises:
      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
      ValueError: If `rhs` is a `LinearOperator` whose statically known range
        dimension differs from this operator's domain dimension.
    """
    if self.is_non_singular is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "be singular.")
    if self.is_square is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "not be square.")
    if isinstance(rhs, linear_operator.LinearOperator):
      left_operator = self.adjoint() if adjoint else self
      right_operator = rhs.adjoint() if adjoint_arg else rhs

      if (right_operator.range_dimension is not None and
          left_operator.domain_dimension is not None and
          right_operator.range_dimension != left_operator.domain_dimension):
        raise ValueError(
            "Operators are incompatible. Expected `rhs` to have dimension"
            " {} but got {}.".format(
                left_operator.domain_dimension, right_operator.range_dimension))
      with self._name_scope(name):  # pylint: disable=not-callable
        return linear_operator_algebra.solve(left_operator, right_operator)

    with self._name_scope(name):  # pylint: disable=not-callable
      block_dimensions = (self._block_domain_dimensions() if adjoint
                          else self._block_range_dimensions())
      arg_dim = -1 if adjoint_arg else -2
      blockwise_arg = linear_operator_util.arg_is_blockwise(
          block_dimensions, rhs, arg_dim)
      if blockwise_arg:
        # Collect the converted blocks in a new list instead of assigning back
        # into `rhs`: the caller's sequence must not be mutated, and it need
        # not support item assignment.
        rhs_blocks = []
        for i, block in enumerate(rhs):
          if not isinstance(block, linear_operator.LinearOperator):
            block = ops.convert_to_tensor_v2_with_dispatch(block)
            self._check_input_dtype(block)
            block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
          rhs_blocks.append(block)
        if adjoint_arg:
          split_rhs = [linalg.adjoint(y) for y in rhs_blocks]
        else:
          split_rhs = rhs_blocks

      else:
        rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
        self._check_input_dtype(rhs)
        op_dimension = (self.domain_dimension if adjoint
                        else self.range_dimension)
        op_dimension.assert_is_compatible_with(rhs.shape[arg_dim])

        rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
        split_rhs = linear_operator_util.split_arg_into_blocks(
            self._block_domain_dimensions(),
            self._block_domain_dimension_tensors,
            rhs, axis=-2)

      solution_list = []
      if adjoint:
        # For an adjoint blockwise lower-triangular linear operator, the system
        # must be solved bottom to top. Iterate backwards over rows of the
        # adjoint (i.e. columns of the non-adjoint operator).
        num_blocks = len(self.operators)
        for index in reversed(range(num_blocks)):
          y = split_rhs[index]
          # Iterate over the operators in the off-diagonal portion of the
          # column-partition (i.e. row-partition of the adjoint), apply
          # the operator to the respective block of the solution found in
          # previous iterations, and subtract the result from the `rhs` block.
          # For example, let `A`, `B`, and `D` be the linear operators in the
          # top row-partition of the adjoint of
          # `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])`,
          # and `x_1` and `x_2` be blocks of the solution found in previous
          # iterations of the outer loop. The following loop (when `index == 0`)
          # expresses
          # `Ax_0 + Bx_1 + Dx_2 = y_0` as `Ax_0 = y_0*`, where
          # `y_0* = y_0 - Bx_1 - Dx_2`.
          for j in reversed(range(index + 1, num_blocks)):
            # Solutions are appended from the last block backwards, so the
            # solution for block `j` sits at position `num_blocks - 1 - j`.
            y = y - self.operators[j][index].matmul(
                solution_list[num_blocks - 1 - j],
                adjoint=adjoint)
          # Continuing the example above, solve `Ax_0 = y_0*` for `x_0`.
          solution_list.append(
              self._diagonal_operators[index].solve(y, adjoint=adjoint))
        # Restore top-to-bottom block order.
        solution_list.reverse()
      else:
        # Iterate top to bottom over the row-partitions.
        for row, y in zip(self.operators, split_rhs):
          # Iterate left to right over the operators in the off-diagonal portion
          # of the row-partition, apply the operator to the block of the
          # solution found in previous iterations, and subtract the result from
          # the `rhs` block. For example, let `D`, `E`, and `F` be the linear
          # operators in the bottom row-partition of
          # `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])` and
          # `x_0` and `x_1` be blocks of the solution found in previous
          # iterations of the outer loop. The following loop
          # (for that third row-partition) expresses
          # `Dx_0 + Ex_1 + Fx_2 = y_2` as `Fx_2 = y_2*`, where
          # `y_2* = y_2 - Dx_0 - Ex_1`.
          for i, operator in enumerate(row[:-1]):
            y = y - operator.matmul(solution_list[i], adjoint=adjoint)
          # Continuing the example above, solve `Fx_2 = y_2*` for `x_2`.
          solution_list.append(row[-1].solve(y, adjoint=adjoint))

      if blockwise_arg:
        return solution_list

      solution_list = linear_operator_util.broadcast_matrix_batch_dims(
          solution_list)
      return array_ops.concat(solution_list, axis=-2)
Esempio n. 37
0
def _SvdGrad(op, grad_s, grad_u, grad_v):
    """Gradient for Svd based on Giles' algorithm. Reference at top of file.

    Args:
      op: The op being differentiated. `op.inputs[0]` is the matrix `a`; when
        both `compute_uv` and `full_matrices` are set, `op.outputs` supplies
        `(s, u, v)`, otherwise they are recomputed from `a`.
      grad_s: Incoming gradient with respect to `s`.
      grad_u: Incoming gradient with respect to `u`.
      grad_v: Incoming gradient with respect to `v`.

    Returns:
      The gradient with respect to `a`, with its static shape set.

    Raises:
      NotImplementedError: For compute_uv=True with full_matrices=False, for
        complex inputs with compute_uv=True, for statically unknown inner
        matrix dimensions, and (when compute_uv=True) for inner dimensions
        that differ by more than one.
    """
    compute_uv = op.get_attr("compute_uv")
    full_matrices = op.get_attr("full_matrices")

    if compute_uv and not full_matrices:
        raise NotImplementedError(
            "SVD gradient is not implemented for compute_uv=True and "
            "full_matrices=False.")

    a = op.inputs[0]
    a_shape = a.get_shape().with_rank_at_least(2)

    if compute_uv:
        # TODO(rmlarsen): Make this work with complex types.
        if a.dtype.is_complex:
            raise NotImplementedError(
                "SVD gradient is not implemented for complex types and "
                "compute_uv=True.")
        # Refine the static shape of `a` with what is known statically about
        # the incoming gradients.
        u_grad_shape = grad_u.get_shape().with_rank_at_least(2)
        v_grad_shape = grad_v.get_shape().with_rank_at_least(2)
        rows_dim = a_shape[-2].merge_with(u_grad_shape[-2])
        cols_dim = a_shape[-1].merge_with(v_grad_shape[-2])
        batch_shape = a_shape[:-2].merge_with(u_grad_shape[:-2])
        batch_shape = batch_shape.merge_with(v_grad_shape[:-2])
        a_shape = batch_shape.concatenate([rows_dim, cols_dim])

    m = a_shape[-2].value
    n = a_shape[-1].value
    # TODO(rmlarsen): Make this work with placeholders.
    if m is None or n is None:
        raise NotImplementedError(
            "SVD gradient has not been implemented for input with unknown "
            "inner matrix shape.")

    if compute_uv and full_matrices:
        s = op.outputs[0]
        u = op.outputs[1]
        v = op.outputs[2]
    else:
        # The forward op did not produce full singular vectors; recompute.
        s, u, v = linalg_ops.svd(a, compute_uv=True, full_matrices=True)

    # For tall matrices, compute the gradient for A^H = V * S^T * U^H instead
    # and (implicitly) take the Hermitian transpose of the result at the end.
    use_adjoint = m > n
    if use_adjoint:
        m, n = n, m
        u, v = v, u
        grad_u, grad_v = grad_v, grad_u

    with ops.control_dependencies([grad_s, grad_u, grad_v]):
        grad_s_mat = array_ops.matrix_diag(grad_s)
        if not compute_uv:
            # Only the singular values carry a gradient: u * diag(grad_s) *
            # v1^H, arranged for whichever orientation is in use.
            v1 = v[..., :, :m]
            if use_adjoint:
                grad_a = math_ops.matmul(v1,
                                         math_ops.matmul(u, grad_s_mat),
                                         adjoint_b=True)
            else:
                grad_a = math_ops.matmul(
                    u, math_ops.matmul(grad_s_mat, v1, adjoint_b=True))
            grad_a.set_shape(a_shape)
            return grad_a

        # TODO(rmlarsen): Define a gradient that is numerically stable for
        # abs(m-n) > 1. Currently this does not work because there are
        # effectively multiple singular values with value zero. I am not sure
        # if this is a true instability or if it simply throws off the finite
        # difference gradient checker.
        if abs(m - n) > 1:
            raise NotImplementedError(
                "svd gradient is not implemented for abs(m - n) > 1")
        s_mat = array_ops.matrix_diag(s)
        s_squared = math_ops.square(s)

        # NOTICE: Because of the term involving f, the gradient becomes
        # infinite (or NaN in practice) when singular values are not unique.
        # Mathematically this should not be surprising, since for (k-fold)
        # degenerate singular values, the corresponding singular vectors are
        # only defined up to a (k-dimensional) subspace. In practice, this can
        # lead to numerical instability when singular values are close but
        # not exactly equal.
        squared_gaps = (array_ops.expand_dims(s_squared, -2) -
                        array_ops.expand_dims(s_squared, -1))
        f = array_ops.matrix_set_diag(
            math_ops.reciprocal(squared_gaps), array_ops.zeros_like(s))
        s_inv_mat = array_ops.matrix_diag(math_ops.reciprocal(s))
        u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
        v_gv = math_ops.matmul(v, grad_v, adjoint_a=True)

        is_square = m == n
        if not is_square:
            # Off-diagonal blocks of v_gv that couple the first m singular
            # vectors with the remaining ones.
            dv2 = array_ops.matrix_transpose(
                v_gv[..., m:n, :m]) - v_gv[..., :m, m:n]
        f_u = f * u_gu
        f_v = f * (v_gv if is_square else v_gv[..., :m, :m])

        herm_u = f_u + _linalg.adjoint(f_u)
        grad_a_nouv = grad_s_mat + math_ops.matmul(herm_u, s_mat)
        herm_v = f_v + _linalg.adjoint(f_v)
        grad_a_nouv = grad_a_nouv + math_ops.matmul(s_mat, herm_v)

        if not is_square:
            grad_a_nouv = array_ops.concat(
                [grad_a_nouv, math_ops.matmul(s_inv_mat, dv2)], -1)

        if use_adjoint:
            # Use (U X V^H)^H = V (U X)^H.
            grad_a = math_ops.matmul(v,
                                     math_ops.matmul(u, grad_a_nouv),
                                     adjoint_b=True)
        else:
            grad_a = math_ops.matmul(
                u, math_ops.matmul(grad_a_nouv, v, adjoint_b=True))

        grad_a.set_shape(a_shape)
        return grad_a
Esempio n. 38
0
    def _matmul(self, x, adjoint=False, adjoint_arg=False):
        """Multiply by the Kronecker product one factor at a time.

        This relies heavily on Roth's column lemma [1]:
        (A x B) * vec X = vec BXA^T,
        where vec stacks all the columns of the matrix under each other. Here
        `x` is treated as a batch of vec X (a batch of column vectors rather
        than a matrix); each member of the batch can be reshaped to a matrix.
        If B is itself a Kronecker product the lemma is applied again, so the
        factors are handled iteratively.

        [1] W. E. Roth, "On direct product matrices,"
        Bulletin of the American Mathematical Society, vol. 40, pp. 461-468,
        1934

        Efficiency: naively forming the dense Kronecker product and applying
        it can take cubic time in the size of domain_dimension (assuming a
        square matrix), and the dense matrix can be prohibitively expensive in
        memory. This implementation only computes matmuls with the factors, so
        the dense matrix is never realized. With factors of size
        (n1, n1), (n2, n2), ... (nJ, nJ), N = prod n_i, and an [N, M] input,
        the naive approach takes O(N^2 M), whereas this approach (ignoring
        reshapes and transposes) can be O(M * (sum n_i) * N). Batched
        multiplication (batch size roughly M * N here) can make it faster
        still. Not factored in are the costs of the several transposes, which
        can affect cache behavior.

        The shape comments below describe the adjoint=False,
        adjoint_arg=False case, but other combinations are handled too.

        Args:
          x: Argument to multiply; replaced by `linalg.adjoint(x)` when
            `adjoint_arg` is true.
          adjoint: Whether to multiply with the adjoint of each factor.
          adjoint_arg: Whether to take the adjoint of `x` before multiplying.

        Returns:
          The product; its static shape is set when `x` has a fully defined
          static shape.
        """
        if adjoint_arg:
            x = linalg.adjoint(x)

        # Always add a batch dimension to enable broadcasting to work: add
        # zeros whose shape is all ones (one per operator batch dim, plus two).
        unit_batch = array_ops.ones_like(self.batch_shape_tensor())
        padding_shape = array_ops.concat([unit_batch, [1, 1]], 0)
        x += array_ops.zeros(padding_shape, dtype=x.dtype.base_dtype)

        # x has shape [B, R, C]: B is some number of batch dimensions, R the
        # number of rows and C the number of columns. Roth's column lemma
        # works on batches of column vectors, so move C to the front (keeping
        # B adjacent to R so that broadcasting against the operators' batch
        # dims still works), then prepend an accumulator axis A to get
        # [A, C, B, R].
        output = _rotate_last_dim(
            x, rotate_right=True)[array_ops.newaxis, ...]

        # Throughout the loop, A is the accumulated leading dimension (1 at
        # the start) and V is the last dimension (R at the start).
        for factor in self.operators[:-1]:
            factor_dim = (factor.range_dimension_tensor() if adjoint
                          else factor.domain_dimension_tensor())

            # [A, C, B, V] -> [A, C, B, V / factor_dim, factor_dim]
            output = _unvec_by(output, factor_dim)

            # Compute (XA^T) as (AX^T)^T: transpose, multiply by this factor,
            # transpose back. The last dimension changes from
            # op.domain_dimension to op.range_dimension.
            output = array_ops.matrix_transpose(
                factor.matmul(array_ops.matrix_transpose(output),
                              adjoint=adjoint,
                              adjoint_arg=False))

            # Fold the multiplied dimension into the accumulator axis:
            # -> [A * op.range_dimension, C, B, V / op.domain_dimension]
            output = _rotate_last_dim(
                _vec(_rotate_last_dim(output, rotate_right=False)),
                rotate_right=True)

        # After the loop:
        #   A = self.range_dimension / op[-1].range_dimension
        #   V = op[-1].domain_dimension
        # The last factor is applied with matvec, giving
        # [A, C, B, op[-1].range_dimension].
        output = self.operators[-1].matvec(output, adjoint=adjoint)

        # Rearrange to [B1, ... Bn, self.range_dimension, C].
        output = _rotate_last_dim(
            _vec(_rotate_last_dim(output, rotate_right=False)),
            rotate_right=False)

        if x.shape.is_fully_defined():
            num_columns = x.shape[-1]
            static_batch = common_shapes.broadcast_shape(
                x.shape[:-2], self.batch_shape)
            num_rows = (self.domain_dimension if adjoint
                        else self.range_dimension)
            output.set_shape(
                static_batch.concatenate([num_rows, num_columns]))

        return output
def _reshape_for_efficiency(a,
                            b,
                            transpose_a=False,
                            transpose_b=False,
                            adjoint_a=False,
                            adjoint_b=False):
  """Maybe fold the extra leading dims of `b` into its last dim (matmul/solve).

  When `b` has more dims than `a`, the extra leading dims of `b` can be moved
  to the end and squashed into the final dim, so that `a` does not have to be
  replicated across them.  The returned callable undoes that rearrangement on
  the result computed from the returned `a` and `b`.

  Args:
    a: Object with a static `.shape`; never reshaped here (it is only
      transposed/adjointed when the flip happens).
    b: Object with a static `.shape`; possibly transposed and reshaped.
    transpose_a: Whether the caller intends to use `a` transposed.
    transpose_b: Whether the caller intends to use `b` transposed.
    adjoint_a: Whether the caller intends to use `a` adjointed.  Takes
      precedence over `transpose_a`.
    adjoint_b: Whether the caller intends to use `b` adjointed.  Takes
      precedence over `transpose_b`.

  Returns:
    Tuple `(a, b, reshape_inv, still_need_to_transpose)`.  If no flip was
    done, `a` and `b` are the inputs, `reshape_inv` is the identity, and
    `still_need_to_transpose` is `True`.  Otherwise the transposes/adjoints
    have already been applied here and `still_need_to_transpose` is `False`.
  """
  def _passthrough(y):
    return y

  # Result used whenever we decide not to touch the inputs: no
  # transpose/adjoint has been applied, so the caller still has to do it.
  untouched = (a, b, _passthrough, True)

  # Bail out if either rank is unknown, or if `b` has no extra leading dims
  # (the latter could be handled in the future, but seems less common).
  if (a.shape.ndims is None or b.shape.ndims is None or
      a.shape.ndims >= b.shape.ndims):
    return untouched

  # From here on `b` may be modified; `a` is at most transposed/adjointed.
  #
  # Suppose:
  #   a.shape =     C + [m, n]
  #   b.shape = S + C + [n, r]
  # Then S holds the `n_extra` leading dims that only `b` has.
  n_extra = b.shape.ndims - a.shape.ndims

  # b_extra_sh = S, b_main_sh = C + [n, r].  These are recomputed further down
  # once any transpose/adjoint has been applied to `b`.
  b_extra_sh = array_ops.shape(b)[:n_extra]
  b_main_sh = array_ops.shape(b)[n_extra:]

  # There is no reason to flip unless the extra dims of `b` are big enough.
  # Assume adjoint/transpose = False.  Without flipping, `a` is replicated to
  # shape b_extra_sh + a.shape, which may cost memory.  In all cases the final
  # output has shape b_extra_sh + a.shape[:-1] + [b.shape[-1]], so a larger
  # object is only created when the end dim of `b` is smaller than the end dim
  # of `a`, e.g. when `b` was a vector expanded to a matrix by appending a
  # singleton.
  #
  # Because adjoint/transpose may be set, pick the relevant axis of each
  # argument accordingly.
  a_axis = -2 if (adjoint_a or transpose_a) else -1
  b_axis = -2 if (adjoint_b or transpose_b) else -1
  a_domain_sz_ = a.shape[a_axis]
  # The dim of `b` that holds the multiple equations.
  b_eq_sz_ = b.shape[b_axis]

  static_extra = b.shape[:n_extra]
  if static_extra.is_fully_defined():
    b_extra_sz_ = np.prod(static_extra.as_list())
  else:
    b_extra_sz_ = None

  all_sizes_known = not (
      a_domain_sz_ is None or b_eq_sz_ is None or b_extra_sz_ is None)
  if all_sizes_known and (b_extra_sz_ < 2 or a_domain_sz_ <= b_eq_sz_):
    return untouched

  # At this point we flip for sure.  Transposes/adjoints are applied here
  # explicitly rather than in calling code, to avoid separate handling of
  # each case.
  if adjoint_a:
    a = linalg.adjoint(a)
  elif transpose_a:
    a = linalg.transpose(a)
  if adjoint_b:
    b = linalg.adjoint(b)
  elif transpose_b:
    b = linalg.transpose(b)

  # Recompute shapes, since the transpose/adjoint may have changed them.
  b_extra_sh = array_ops.shape(b)[:n_extra]
  b_main_sh = array_ops.shape(b)[n_extra:]

  # Permutation that moves the extra dims to the end.
  main_axes = math_ops.range(n_extra, b.shape.ndims)
  extra_axes = math_ops.range(0, n_extra)
  perm = array_ops.concat((main_axes, extra_axes), 0)
  b_extra_on_end = array_ops.transpose(b, perm=perm)

  # Squash the trailing dims into one long dim.
  squashed_shape = array_ops.concat((b_main_sh[:-1], [-1]), 0)
  b_squashed_end = array_ops.reshape(b_extra_on_end, squashed_shape)

  def reshape_inv(y):
    # Expand the extra dims hanging off the end, "b_extra_sh".
    # shape(y)[:-1] + [b_main_sh[-1]] is used rather than b_main_sh, because
    # `y` could have different batch dims than `a` and `b` due to
    # broadcasting.
    y_leading = array_ops.shape(y)[:-1]
    y_expanded_shape = array_ops.concat(
        (y_leading, [b_main_sh[-1]], b_extra_sh), 0)
    y_expanded = array_ops.reshape(y, y_expanded_shape)
    inverse_perm = array_ops.invert_permutation(perm)
    return array_ops.transpose(y_expanded, perm=inverse_perm)

  return a, b_squashed_end, reshape_inv, False
  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solve the blockwise lower-triangular system by block substitution.

    Given the blockwise `n + 1`-by-`n + 1` linear operator

      op = [[A_00     0  ...     0  ...    0],
            [A_10  A_11  ...     0  ...    0],
            ...
            [A_k0  A_k1  ...  A_kk  ...    0],
            ...
            [A_n0  A_n1  ...  A_nk  ... A_nn]]

    `x = op.solve(y)` satisfies

      `y_k = A_k0.matmul(x_0) + A_k1.matmul(x_1) + ... + A_kk.matmul(x_k)`

    hence

      `x_k = A_kk.solve(y_k -
                        A_k0.matmul(x_0) - ... - A_k(k-1).matmul(x_(k-1)))`

    where `x_k` and `y_k` are the `k`th blocks of `x` and `y`.  The
    non-adjoint case starts from `x_0 = A_00.solve(y_0)` and proceeds
    forwards; the adjoint case starts from
    `x_n = A_nn.solve(y_n, adjoint=True)` and proceeds backwards.

    Args:
      rhs: Right-hand side; it is split into blocks along axis `-2`.
      adjoint: If true, solve against the adjoint of this operator.
      adjoint_arg: If true, `rhs` is replaced by its adjoint before solving.

    Returns:
      The solution blocks, broadcast to common batch dims and concatenated
      along axis `-2`.
    """
    if adjoint_arg:
      rhs = linalg.adjoint(rhs)
    rhs_blocks = self._split_input_into_blocks(rhs, axis=-2)

    solved_blocks = []
    if adjoint:
      # The adjoint of a blockwise lower-triangular operator must be solved
      # bottom to top.  Walk backwards over the rows of the adjoint, i.e. over
      # the columns of the non-adjoint operator.
      num_blocks = len(self.operators)
      for col in range(num_blocks - 1, -1, -1):
        y = rhs_blocks[col]
        # Subtract the contribution of every already-solved block.  Example:
        # for `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])`
        # the top row of the adjoint holds `A`, `B`, `D`.  With `col == 0` and
        # `x_1`, `x_2` already known,
        # `Ax_0 + Bx_1 + Dx_2 = y_0` becomes `Ax_0 = y_0*` with
        # `y_0* = y_0 - Bx_1 - Dx_2`.
        # `solved_blocks` is filled last block first, so the solution for
        # block `row` sits at position `num_blocks - 1 - row`.
        for row in range(num_blocks - 1, col, -1):
          y -= self.operators[row][col].matmul(
              solved_blocks[num_blocks - 1 - row],
              adjoint=adjoint)
        # Continuing the example: solve `Ax_0 = y_0*` for `x_0`.
        solved_blocks.append(
            self.operators[col][col].solve(y, adjoint=adjoint))
      # Put the blocks back into top-to-bottom order.
      solved_blocks.reverse()
    else:
      # Walk the row-partitions top to bottom.
      for block_row, y in zip(self.operators, rhs_blocks):
        # Subtract the off-diagonal operators applied to the blocks solved in
        # earlier iterations.  Example: for the bottom row `D`, `E`, `F` of
        # `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])` with
        # `x_0`, `x_1` already known,
        # `Dx_0 + Ex_1 + Fx_2 = y_2` becomes `Fx_2 = y_2*` with
        # `y_2* = y_2 - Dx_0 - Ex_1`.
        for position, off_diag_op in enumerate(block_row[:-1]):
          y -= off_diag_op.matmul(solved_blocks[position], adjoint=adjoint)
        # Continuing the example: solve `Fx_2 = y_2*` for `x_2`.
        diag_op = block_row[-1]
        solved_blocks.append(diag_op.solve(y, adjoint=adjoint))

    solved_blocks = linear_operator_util.broadcast_matrix_batch_dims(
        solved_blocks)
    return array_ops.concat(solved_blocks, axis=-2)
Esempio n. 41
0
    def _solve(self, rhs, adjoint=False, adjoint_arg=False):
        """Solve by applying `solve` of each operator in `self.operators`.

        This mirrors the Roth's-column-lemma approach used for `matmul`, with
        every `matmul` replaced by `solve`, relying on
        inv(A x B) = inv(A) x inv(B).

        The shape comments below describe adjoint=False, adjoint_arg=False;
        the other adjoint combinations go through the same code.

        Args:
          rhs: Right-hand side of shape [B, R, C] (B = batch dims, R = rows,
            C = columns).
          adjoint: If true, solve against the adjoint of each operator.
          adjoint_arg: If true, `rhs` is replaced by its adjoint first.

        Returns:
          The solution; its static shape is set when `rhs` has a fully
          defined static shape.
        """
        if adjoint_arg:
            rhs = linalg.adjoint(rhs)

        # Always add a batch dimension so that broadcasting works: add zeros
        # whose shape is all ones, with as many batch dims as this operator.
        unit_batch = array_ops.ones_like(self.batch_shape_tensor())
        unit_shape = array_ops.concat([unit_batch, [1, 1]], 0)
        rhs += array_ops.zeros(unit_shape, dtype=rhs.dtype.base_dtype)

        # To apply Roth's column lemma we need a batch of column vectors, so
        # rotate [B, R, C] into [C, B, R].  The column dim goes to the front
        # so that broadcasting between the operators and B still works.
        output = _rotate_last_dim(rhs, rotate_right=True)

        # Prepend an accumulator dim, giving [A, C, B, R].  It collects the
        # dimensions produced by each operator's solve.
        output = output[array_ops.newaxis, ...]

        # In the loop, A is the accumulated dim (1 at the start,
        # self.range_dimension at the very end) and V is the last dim (R at
        # the start, 1 at the very end).
        for operator in self.operators[:-1]:
            # Split [A, C, B, V] into
            # [A, C, B, V / op.domain_dimension, op.domain_dimension].
            dimension_fn = (operator.range_dimension_tensor
                            if adjoint else operator.domain_dimension_tensor)
            output = _unvec_by(output, dimension_fn())

            # We are computing (XA^-1^T) = (A^-1 X^T)^T, which turns
            # [A, C, B, V / op.domain_dimension, op.domain_dimension] into
            # [A, C, B, V / op.domain_dimension, op.range_dimension].
            output = array_ops.matrix_transpose(
                operator.solve(array_ops.matrix_transpose(output),
                               adjoint=adjoint,
                               adjoint_arg=False))

            # Rearrange to
            # [A * op.range_dimension, C, B, V / op.domain_dimension].
            output = _rotate_last_dim(
                _vec(_rotate_last_dim(output, rotate_right=False)),
                rotate_right=True)

        # After the loop:
        #   A = self.range_dimension / op[-1].range_dimension
        #   V = op[-1].domain_dimension
        # A final solvevec yields [A, C, B, op[-1].range_dimension].
        last_operator = self.operators[-1]
        output = last_operator.solvevec(output, adjoint=adjoint)

        # Rearrange to [B1, ... Bn, self.range_dimension, C].
        output = _rotate_last_dim(output, rotate_right=False)
        output = _vec(output)
        output = _rotate_last_dim(output, rotate_right=False)

        # Attach the static shape when it can be determined.
        if rhs.shape.is_fully_defined():
            column_dim = rhs.shape[-1]
            broadcast_batch_shape = common_shapes.broadcast_shape(
                rhs.shape[:-2], self.batch_shape)
            row_dim = (self.domain_dimension
                       if adjoint else self.range_dimension)
            output.set_shape(
                broadcast_batch_shape.concatenate([row_dim, column_dim]))

        return output
 def _matmul(self, x, adjoint=False, adjoint_arg=False):
   """Multiply `x` elementwise by `self._diag` expanded with a trailing dim.

   Args:
     x: Argument to multiply.
     adjoint: If true, the conjugate of `self._diag` is used.
     adjoint_arg: If true, `x` is replaced by its adjoint first.

   Returns:
     `expand_dims(diag, -1) * x`.
   """
   scale = self._diag
   if adjoint:
     scale = math_ops.conj(scale)
   if adjoint_arg:
     x = linalg.adjoint(x)
   # The trailing singleton dim lets the diagonal broadcast over the last
   # axis of `x`.
   return array_ops.expand_dims(scale, -1) * x
Esempio n. 43
0
 def _TriangularSolve(x, r):
   """Equiv to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri.

   Computed as the adjoint of the upper-triangular solve of `r` against the
   adjoint of `x`, so no explicit inverse is formed.
   """
   x_adjoint = _linalg.adjoint(x)
   solved = linalg_ops.matrix_triangular_solve(
       r, x_adjoint, lower=False, adjoint=False)
   return _linalg.adjoint(solved)
Esempio n. 44
0
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
     """Multiply `rhs` elementwise by the reciprocal of `self._diag`.

     Args:
       rhs: Right-hand side.
       adjoint: If true, the conjugate of `self._diag` is used.
       adjoint_arg: If true, `rhs` is replaced by its adjoint first.

     Returns:
       `rhs * expand_dims(1. / diag, -1)`.
     """
     divisor = self._diag
     if adjoint:
         divisor = math_ops.conj(divisor)
     if adjoint_arg:
         rhs = linalg.adjoint(rhs)
     # The trailing singleton dim lets the reciprocal broadcast over the last
     # axis of `rhs`.
     reciprocal = 1. / divisor
     return rhs * array_ops.expand_dims(reciprocal, -1)
 def _to_dense(self):
   """Return the dense form of `self.operator`, adjointed unless self-adjoint.

   When `self.is_self_adjoint` is truthy the dense matrix of the wrapped
   operator is returned unchanged; otherwise its adjoint is returned.
   """
   if not self.is_self_adjoint:
     return linalg.adjoint(self.operator.to_dense())
   return self.operator.to_dense()
def _test_solve_base(
    self,
    use_placeholder,
    shapes_info,
    dtype,
    adjoint,
    adjoint_arg,
    blockwise_arg,
    with_batch):
  """Check `operator.solve` against a dense solve with the operator's matrix.

  Builds an operator and its matrix via `self.operator_and_matrix`, builds a
  right-hand side via `self.make_rhs`, and asserts that `operator.solve`
  agrees with `matrix_solve_with_broadcast`.  When `blockwise_arg` is set and
  the operator has more than one block, a blockwise (list) right-hand side is
  checked as well.

  Args:
    self: The test case instance.
    use_placeholder: Forwarded to `operator_and_matrix`; when false, the
      static shapes of the two solutions are also compared.
    shapes_info: Shape description; its `.shape` is inspected and it is
      forwarded to `operator_and_matrix`.
    dtype: Forwarded to `operator_and_matrix`.
    adjoint: Forwarded to both solves.
    adjoint_arg: If true, the adjoint of the rhs is passed to the operator
      together with `adjoint_arg=True`.
    blockwise_arg: Whether to additionally test a blockwise rhs.
    with_batch: Forwarded to `make_rhs`.
  """
  # A batch-less rhs with a batch-less operator is already covered by the
  # with_batch=True case, so skip it.
  if not (with_batch or len(shapes_info.shape) > 2):
    return
  with self.session(graph=ops.Graph()) as sess:
    sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
    operator, mat = self.operator_and_matrix(
        shapes_info, dtype, use_placeholder=use_placeholder)
    rhs = self.make_rhs(
        operator, adjoint=adjoint, with_batch=with_batch)

    # If adjoint_arg, solve A X = (rhs^H)^H = rhs.
    operator_input = linalg.adjoint(rhs) if adjoint_arg else rhs
    op_solve = operator.solve(
        operator_input, adjoint=adjoint, adjoint_arg=adjoint_arg)
    mat_solve = linear_operator_util.matrix_solve_with_broadcast(
        mat, rhs, adjoint=adjoint)
    if not use_placeholder:
      self.assertAllEqual(op_solve.shape,
                          mat_solve.shape)

    # Everything is evaluated in a single `sess.run` call to avoid
    # re-sampling the random rhs in graph mode.  The `Tensor` rhs is always
    # tested; the blockwise rhs only when the operator is blockwise.
    fetches = [op_solve, mat_solve]
    check_blockwise = blockwise_arg and len(operator.operators) > 1
    if check_blockwise:
      # pylint: disable=protected-access
      if adjoint:
        block_dimensions = operator._block_range_dimensions()
        block_dimensions_fn = operator._block_range_dimension_tensors
      else:
        block_dimensions = operator._block_domain_dimensions()
        block_dimensions_fn = operator._block_domain_dimension_tensors
      # pylint: enable=protected-access
      split_rhs = linear_operator_util.split_arg_into_blocks(
          block_dimensions,
          block_dimensions_fn,
          rhs, axis=-2)
      if adjoint_arg:
        split_rhs = [linalg.adjoint(block) for block in split_rhs]
      split_solve = operator.solve(
          split_rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
      self.assertEqual(len(split_solve), len(operator.operators))
      split_solve = linear_operator_util.broadcast_matrix_batch_dims(
          split_solve)
      fetches.append(array_ops.concat(split_solve, axis=-2))

    values = sess.run(fetches)
    op_solve_v, mat_solve_v = values[0], values[1]

    if check_blockwise:
      # The operator and the matrix must give the same solution when the rhs
      # is blockwise.
      self.assertAC(mat_solve_v, values[2])

    # The operator and the matrix must give the same solution when the rhs is
    # a `Tensor`.
    self.assertAC(op_solve_v, mat_solve_v)
Esempio n. 47
0
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
     """Triangular solve against `self._tril` with its diagonal set to exp(diag).

     Args:
       rhs: Right-hand side.
       adjoint: Forwarded to the triangular solve.
       adjoint_arg: If true, `rhs` is replaced by its adjoint first.

     Returns:
       Result of the broadcasting lower-triangular solve.
     """
     if adjoint_arg:
         rhs = linalg.adjoint(rhs)
     # The matrix actually solved against is `self._tril` with its diagonal
     # overwritten by `exp(self._diag)`.
     exp_diag = math_ops.exp(self._diag)
     lower_matrix = array_ops.matrix_set_diag(self._tril, exp_diag)
     return linear_operator_util.matrix_triangular_solve_with_broadcast(
         lower_matrix, rhs, lower=True, adjoint=adjoint)
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
   """Lower-triangular solve of `self._tril` against `rhs`.

   Args:
     rhs: Right-hand side.
     adjoint: Forwarded to `matrix_triangular_solve`.
     adjoint_arg: If true, `rhs` is replaced by its adjoint first.

   Returns:
     Result of `matrix_triangular_solve` with `lower=True`.
   """
   if adjoint_arg:
     rhs = linalg.adjoint(rhs)
   return linalg_ops.matrix_triangular_solve(
       self._tril, rhs, lower=True, adjoint=adjoint)
Esempio n. 49
0
 def _matmul(self, x, adjoint=False, adjoint_arg=False):
     """Multiply `x` elementwise by `self._diag` expanded with a trailing dim.

     Args:
       x: Argument to multiply.
       adjoint: If true, the conjugate of `self._diag` is used.
       adjoint_arg: If true, `x` is replaced by its adjoint first.

     Returns:
       `expand_dims(diag, -1) * x`.
     """
     multiplier = self._diag
     if adjoint:
         multiplier = math_ops.conj(multiplier)
     if adjoint_arg:
         x = linalg.adjoint(x)
     # The trailing singleton dim lets the diagonal broadcast over the last
     # axis of `x`.
     column_of_diag = array_ops.expand_dims(multiplier, -1)
     return column_of_diag * x