Example #1
    def _check_shapes(self):
        """Static check that shapes are compatible."""
        # Broadcast shape also checks that u and v are compatible.
        uv_shape = _ops.broadcast_static_shape(
            tensor_shape.TensorShape(self.u.shape),
            tensor_shape.TensorShape(self.v.shape))

        batch_shape = _ops.broadcast_static_shape(
            self.base_operator.batch_shape, uv_shape[:-2])

        tensor_shape.Dimension(
            self.base_operator.domain_dimension).assert_is_compatible_with(
                uv_shape[-2])

        if self._diag_update is not None:
            tensor_shape.dimension_at_index(
                uv_shape, -1).assert_is_compatible_with(
                    tensor_shape.TensorShape(self._diag_update.shape)[-1])
            _ops.broadcast_static_shape(
                batch_shape,
                tensor_shape.TensorShape(self._diag_update.shape)[:-1])
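As a quick illustration of the contract `_check_shapes` enforces for a low-rank update `A + U D V^H`, here is a minimal NumPy sketch; the shapes are hypothetical and `np.broadcast_shapes` stands in for `_ops.broadcast_static_shape`.

```python
import numpy as np

# Hypothetical shapes: base operator has batch [2, 3] and matrix [5, 5].
base_batch, m = (2, 3), 5
u = np.zeros((2, 3, 5, 2))   # u.shape = [B1, B2, M, K]
v = np.zeros((2, 3, 5, 2))   # v.shape must broadcast against u.shape
diag = np.zeros((2, 3, 2))   # diag_update.shape = [B1, B2, K]

# Broadcasting u against v also checks they are compatible (raises otherwise).
uv_shape = np.broadcast_shapes(u.shape, v.shape)
# Batch dims of u/v must broadcast with the base operator's batch shape.
batch_shape = np.broadcast_shapes(base_batch, uv_shape[:-2])
# Second-to-last dim of u/v must match the base operator's domain dimension.
assert uv_shape[-2] == m
# Last dim of u/v must match the diagonal update's length K.
assert uv_shape[-1] == diag.shape[-1]
```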
Example #2
    def batch_shape(self):
        """`TensorShape` of batch dimensions of this `LinearOperator`.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns
    `TensorShape([B1,...,Bb])`, equivalent to `A.shape[:-2]`.

    Returns:
      `TensorShape`, statically determined, may be undefined.
    """
        # Derived classes get this "for free" once .shape is implemented.
        return tensor_shape.TensorShape(self.shape)[:-2]

    def _shape_tensor(self):
        # Avoid messy broadcasting if possible.
        if tensor_shape.TensorShape(self.shape).is_fully_defined():
            return ops.convert_to_tensor(tensor_shape.TensorShape(
                self.shape).as_list(),
                                         dtype=dtypes.int32,
                                         name="shape")

        # Don't check the matrix dimensions.  That would add unnecessary Asserts to
        # the graph.  Things will fail at runtime naturally if shapes are
        # incompatible.
        matrix_shape = array_ops.stack([
            self.operators[0].range_dimension_tensor(),
            self.operators[-1].domain_dimension_tensor()
        ])

        # Dummy Tensor of zeros.  Will never be materialized.
        zeros = array_ops.zeros(shape=self.operators[0].batch_shape_tensor())
        for operator in self.operators[1:]:
            zeros += array_ops.zeros(shape=operator.batch_shape_tensor())
        batch_shape = array_ops.shape(zeros)

        return array_ops.concat((batch_shape, matrix_shape), 0)
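The `_shape_tensor` above relies on a small trick worth spelling out: summing zero tensors broadcasts their shapes, so the shape of the result is the broadcast batch shape. In graph-mode TF the zeros need never be materialized; in this NumPy sketch they are, but the idea is the same.

```python
import numpy as np

batch_shapes = [(3, 1), (1, 4), (3, 4)]  # hypothetical per-operator batches
zeros = np.zeros(batch_shapes[0])
for batch_shape in batch_shapes[1:]:
    zeros = zeros + np.zeros(batch_shape)  # broadcasting does the work
print(zeros.shape)  # (3, 4), the broadcast of all the batch shapes
```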
Example #4
  def _shape(self):
    # Get final matrix shape.
    domain_dimension = sum(self._block_domain_dimensions())
    range_dimension = sum(self._block_range_dimensions())
    # Shape is [M, N] with M = range dimension, N = domain dimension.
    matrix_shape = tensor_shape.TensorShape([range_dimension, domain_dimension])

    # Get broadcast batch shape.
    # broadcast_shape checks for compatibility.
    batch_shape = self.operators[0].batch_shape
    for operator in self.operators[1:]:
      batch_shape = common_shapes.broadcast_shape(
          batch_shape, operator.batch_shape)

    return batch_shape.concatenate(matrix_shape)
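A NumPy sketch of the same shape arithmetic for a hypothetical two-block operator: matrix dimensions add up across blocks, while batch dimensions broadcast.

```python
import numpy as np

block_matrix_shapes = [(2, 3), (4, 4)]   # per-block [M_i, N_i]
block_batch_shapes = [(5, 1), (1, 7)]    # per-block batch dims

range_dimension = sum(m for m, _ in block_matrix_shapes)    # 6
domain_dimension = sum(n for _, n in block_matrix_shapes)   # 7
batch_shape = np.broadcast_shapes(*block_batch_shapes)      # (5, 7)
print(batch_shape + (range_dimension, domain_dimension))    # (5, 7, 6, 7)
```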
Example #5
    def _trace(self):
        # The diagonal of the [[nested] block] circulant operator is the mean of
        # the spectrum.
        # Proof:  For the [0,...,0] element, this follows from the IDFT formula.
        # Then the result follows since all diagonal elements are the same.
        # Therefore, the trace is N times the mean, i.e. the sum of the spectrum.

        # Get shape of diag along with the axis over which to reduce the spectrum.
        # We will reduce the spectrum over all block indices.
        if tensor_shape.TensorShape(self.spectrum.shape).is_fully_defined():
            spec_rank = tensor_shape.TensorShape(self.spectrum.shape).ndims
            axis = np.arange(spec_rank - self.block_depth,
                             spec_rank,
                             dtype=np.int32)
        else:
            spec_rank = array_ops.rank(self.spectrum)
            axis = array_ops.range(spec_rank - self.block_depth, spec_rank)

        # Real diag part "re_d".
        # Suppose spectrum.shape = [B1,...,Bb, N1, N2],
        # self.shape = [B1,...,Bb, N, N], with N1 * N2 = N.
        # Then re_d_value.shape = [B1,...,Bb].
        re_d_value = math_ops.reduce_sum(math_ops.real(self.spectrum),
                                         axis=axis)

        if not np.issubdtype(self.dtype, np.complexfloating):
            return _ops.cast(re_d_value, self.dtype)

        # Imaginary part, "im_d".
        if self.is_self_adjoint:
            im_d_value = array_ops.zeros_like(re_d_value)
        else:
            im_d_value = math_ops.reduce_sum(math_ops.imag(self.spectrum),
                                             axis=axis)

        return _ops.cast(math_ops.complex(re_d_value, im_d_value), self.dtype)
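The trace-equals-sum-of-spectrum claim is easy to check numerically for an ordinary (depth-1) circulant matrix; a quick NumPy verification, assuming a real random first column:

```python
import numpy as np

n = 6
c = np.random.randn(n)                   # first column of a circulant matrix
spectrum = np.fft.fft(c)                 # its eigenvalues (the "spectrum")
# Dense circulant matrix: C[i, j] = c[(i - j) % n].
C = np.array([np.roll(c, k) for k in range(n)]).T
# Every diagonal element is the mean of the spectrum ...
np.testing.assert_allclose(np.diag(C), np.full(n, spectrum.mean().real))
# ... so the trace is the sum of the spectrum.
np.testing.assert_allclose(np.trace(C), spectrum.sum().real)
```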
Example #6
    def _unblockify_then_matricize(self, vec):
        """Flatten the block dimensions then reshape to a batch matrix."""
        # Suppose
        #   vec.shape = [v0, v1, v2, v3],
        #   self.block_depth = 2.
        # Then
        #   leading shape = [v0, v1]
        #   block shape = [v2, v3].
        # We will reshape vec to
        #   [v1, v2*v3, v0].

        # Un-blockify: Flatten block dimensions.  Reshape
        #   [v0, v1, v2, v3] --> [v0, v1, v2*v3].
        if tensor_shape.TensorShape(vec.shape).is_fully_defined():
            # vec_shape = [v0, v1, v2, v3]
            vec_shape = tensor_shape.TensorShape(vec.shape).as_list()
            # vec_leading_shape = [v0, v1]
            vec_leading_shape = vec_shape[:-self.block_depth]
            # vec_block_shape = [v2, v3]
            vec_block_shape = vec_shape[-self.block_depth:]
            # flat_shape = [v0, v1, v2*v3]
            flat_shape = vec_leading_shape + [np.prod(vec_block_shape)]
        else:
            vec_shape = prefer_static.shape(vec)
            vec_leading_shape = vec_shape[:-self.block_depth]
            vec_block_shape = vec_shape[-self.block_depth:]
            flat_shape = prefer_static.concat(
                (vec_leading_shape, [math_ops.reduce_prod(vec_block_shape)]),
                0)
        vec_flat = array_ops.reshape(vec, flat_shape)

        # Matricize:  Reshape to batch matrix.
        #   [v0, v1, v2*v3] --> [v1, v2*v3, v0],
        # representing a shape [v1] batch of [v2*v3, v0] matrices.
        matrix = distribution_util.rotate_transpose(vec_flat, shift=-1)
        return matrix
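A NumPy sketch of the two steps for `block_depth = 2`; `np.moveaxis(x, 0, -1)` plays the role of `rotate_transpose(x, shift=-1)`:

```python
import numpy as np

v0, v1, v2, v3 = 2, 3, 4, 5
vec = np.arange(v0 * v1 * v2 * v3).reshape(v0, v1, v2, v3)

vec_flat = vec.reshape(v0, v1, v2 * v3)  # un-blockify: [v0, v1, v2*v3]
matrix = np.moveaxis(vec_flat, 0, -1)    # matricize:   [v1, v2*v3, v0]
print(matrix.shape)  # (3, 20, 2)
```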
Example #7
    def matvec(self, x, adjoint=False, name="matvec"):
        """Transform [batch] vector `x` with left multiplication:  `x --> Ax`.

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)

    X = ... # shape [..., N], batch vector

    Y = operator.matvec(X)
    Y.shape
    ==> [..., M]

    Y[..., :] = sum_j A[..., :, j] X[..., j]
    ```

    Args:
      x: `Tensor` with compatible shape and same `dtype` as `self`.
        `x` is treated as a [batch] vector, meaning for every set of leading
        dimensions, the last dimension defines a vector.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`.  If `True`, left multiply by the adjoint: `A^H x`.
      name:  A name for this `Op`.

    Returns:
      A `Tensor` with shape `[..., M]` and same `dtype` as `self`.
    """
        with self._name_scope(name):
            x = ops.convert_to_tensor(x, name="x")
            # self._check_input_dtype(x)
            self_dim = -2 if adjoint else -1
            tensor_shape.dimension_at_index(
                tensor_shape.TensorShape(self.shape),
                self_dim).assert_is_compatible_with(
                    tensor_shape.TensorShape(x.shape)[-1])
            return self._matvec(x, adjoint=adjoint)
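The batch contract in the docstring is just a broadcasted matrix-vector product over the last dimension; a NumPy `einsum` sketch with hypothetical shapes:

```python
import numpy as np

A = np.random.randn(7, 2, 4, 3)  # batch [7, 2], matrix [M, N] = [4, 3]
X = np.random.randn(7, 2, 3)     # batch vector with last dimension N = 3
# Y[..., :] = sum_j A[..., :, j] X[..., j]
Y = np.einsum('...ij,...j->...i', A, X)
print(Y.shape)  # (7, 2, 4)
```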
Example #8
    def tensor_rank(self, name="tensor_rank"):
        """Rank (in the sense of tensors) of matrix corresponding to this operator.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`.

    Args:
      name:  A name for this `Op`.

    Returns:
      Python integer, or None if the tensor rank is undefined.
    """
        # Derived classes get this "for free" once .shape() is implemented.
        with self._name_scope(name):
            return tensor_shape.TensorShape(self.shape).ndims
Example #9
  def _eigvals(self):
    # This will be the Kronecker product of all the eigenvalues.
    # Note: It doesn't matter which Kronecker product it is, since all
    # Kronecker products of the same matrices are similar.
    eigvals = [operator.eigvals() for operator in self.operators]
    # Now compute the Kronecker product.
    product = eigvals[0]
    for eigval in eigvals[1:]:
      # Product has shape [B, R1, 1].
      product = product[..., _ops.newaxis]
      # Eigval has shape [B, 1, R2]. Produces shape [B, R1, R2].
      product = product * eigval[..., _ops.newaxis, :]
      # Reshape to [B, R1 * R2].
      product = array_ops.reshape(
          product,
          shape=prefer_static.concat(
              [prefer_static.shape(product)[:-2], [-1]], axis=0))
    tensorshape_util.set_shape(
        product, tensor_shape.TensorShape(self.shape)[:-1])
    return product
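The eigenvalue fact this relies on checks out numerically: the eigenvalues of a Kronecker product are the pairwise products of the factors' eigenvalues, i.e. the Kronecker product of the eigenvalue vectors, up to ordering. A NumPy sketch:

```python
import numpy as np

a = np.random.randn(3, 3)
b = np.random.randn(4, 4)
eig_kron = np.linalg.eigvals(np.kron(a, b))
kron_eig = np.kron(np.linalg.eigvals(a), np.linalg.eigvals(b))
# Compare as multisets, since the two orderings differ.
np.testing.assert_allclose(
    np.sort_complex(eig_kron), np.sort_complex(kron_eig), atol=1e-8)
```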
  def _shape(self):
    # Get final matrix shape.
    domain_dimension = self.operators[0].domain_dimension
    for operator in self.operators[1:]:
      domain_dimension.assert_is_compatible_with(operator.range_dimension)
      domain_dimension = operator.domain_dimension

    matrix_shape = tensor_shape.TensorShape(
        [self.operators[0].range_dimension,
         self.operators[-1].domain_dimension])

    # Get broadcast batch shape.
    # broadcast_shape checks for compatibility.
    batch_shape = self.operators[0].batch_shape
    for operator in self.operators[1:]:
      batch_shape = common_shapes.broadcast_shape(
          batch_shape, operator.batch_shape)

    return batch_shape.concatenate(matrix_shape)
Example #11
    def _to_dense(self):
        product = self.operators[0].to_dense()
        for operator in self.operators[1:]:
            # Product has shape [B, R1, 1, C1, 1].
            product = product[..., :, _ops.newaxis, :, _ops.newaxis]
            # Operator has shape [B, 1, R2, 1, C2].
            op_to_mul = operator.to_dense()[..., _ops.newaxis, :,
                                            _ops.newaxis, :]
            # This is now [B, R1, R2, C1, C2].
            product = product * op_to_mul
            # Now merge together dimensions to get [B, R1 * R2, C1 * C2].
            product_shape = _prefer_static_shape(product)
            shape = _prefer_static_concat_shape(product_shape[:-4], [
                product_shape[-4] * product_shape[-3],
                product_shape[-2] * product_shape[-1]
            ])

            product = array_ops.reshape(product, shape=shape)
        tensorshape_util.set_shape(product,
                                   tensor_shape.TensorShape(self.shape))
        return product
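The broadcast-and-reshape pattern above is exactly a Kronecker product; here is a minimal single-factor-pair NumPy check against `np.kron`:

```python
import numpy as np

a = np.random.randn(2, 3)  # [R1, C1]
b = np.random.randn(4, 5)  # [R2, C2]

product = a[:, None, :, None] * b[None, :, None, :]  # [R1, R2, C1, C2]
product = product.reshape(2 * 4, 3 * 5)              # [R1*R2, C1*C2]
np.testing.assert_allclose(product, np.kron(a, b))
```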
  def _to_dense(self):
    num_cols = 0
    dense_rows = []
    flat_broadcast_operators = linear_operator_util.broadcast_matrix_batch_dims(
        [op.to_dense() for row in self.operators for op in row])  # pylint: disable=g-complex-comprehension
    broadcast_operators = [
        flat_broadcast_operators[i * (i + 1) // 2:(i + 1) * (i + 2) // 2]
        for i in range(len(self.operators))]
    for row_blocks in broadcast_operators:
      batch_row_shape = prefer_static.shape(row_blocks[0])[:-1]
      num_cols = num_cols + prefer_static.shape(row_blocks[-1])[-1]
      zeros_to_pad_after_shape = prefer_static.concat(
          [batch_row_shape,
           [self.domain_dimension_tensor() - num_cols]], axis=-1)
      zeros_to_pad_after = array_ops.zeros(
          shape=zeros_to_pad_after_shape, dtype=self.dtype)

      row_blocks.append(zeros_to_pad_after)
      dense_rows.append(prefer_static.concat(row_blocks, axis=-1))

    mat = prefer_static.concat(dense_rows, axis=-2)
    tensorshape_util.set_shape(mat, tensor_shape.TensorShape(self.shape))
    return mat
Example #13
def _broadcast_static_shape(shape_x, shape_y):
  """Reimplements `tf.broadcast_static_shape` in JAX/NumPy."""
  if (tensor_shape.TensorShape(shape_x).ndims is None or
      tensor_shape.TensorShape(shape_y).ndims is None):
    return tensor_shape.TensorShape(None)
  shape_x = tuple(tensor_shape.TensorShape(shape_x).as_list())
  shape_y = tuple(tensor_shape.TensorShape(shape_y).as_list())
  try:
    if JAX_MODE:
      error_message = 'Incompatible shapes for broadcasting'
      return tensor_shape.TensorShape(lax.broadcast_shapes(shape_x, shape_y))
    error_message = ('shape mismatch: objects cannot be broadcast to'
                     ' a single shape')
    return tensor_shape.TensorShape(
        np.broadcast(np.zeros(shape_x), np.zeros(shape_y)).shape)
  except ValueError as e:
    # Match TF error message
    if error_message in str(e):
      raise ValueError(
          'Incompatible shapes for broadcasting: {} and {}'.format(
              shape_x, shape_y))
    raise
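Hypothetical usage of the helper above, illustrating its three behaviors (broadcast, unknown rank, incompatibility); expected results are noted in comments:

```python
_broadcast_static_shape([3, 1, 5], [4, 5])  # TensorShape([3, 4, 5])
_broadcast_static_shape(None, [4, 5])       # TensorShape(None): rank unknown
_broadcast_static_shape([3, 2], [4, 5])     # raises ValueError: incompatible
```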
Example #14
    def _to_dense(self):
        product = self.operators[0].to_dense()
        for operator in self.operators[1:]:
            # Product has shape [B, R1, 1, C1, 1].
            product = product[..., :, _ops.newaxis, :, _ops.newaxis]
            # Operator has shape [B, 1, R2, 1, C2].
            op_to_mul = operator.to_dense()[..., _ops.newaxis, :,
                                            _ops.newaxis, :]
            # This is now [B, R1, R2, C1, C2].
            product *= op_to_mul
            # Now merge together dimensions to get [B, R1 * R2, C1 * C2].
            product_shape = array_ops.shape(product)
            product = array_ops.reshape(
                product,
                shape=array_ops.concat(
                    [product_shape[:-4],
                     [product_shape[-4] * product_shape[-3],
                      product_shape[-2] * product_shape[-1]]],
                    axis=0))
        product.set_shape(tensor_shape.TensorShape(self.shape))
        return product

    def _solve(self, rhs, adjoint=False, adjoint_arg=False):
        # Here we follow the same use of Roth's column lemma as in `matmul`, with
        # the key difference that we replace all `matmul` instances with `solve`.
        # This follows from the property that inv(A x B) = inv(A) x inv(B).

        # Below we document the shape manipulation for adjoint=False,
        # adjoint_arg=False, but the general case of different adjoints is still
        # handled.

        if adjoint_arg:
            rhs = linalg.adjoint(rhs)

        # Always add a batch dimension to enable broadcasting to work.
        batch_shape = self._compute_ones_matrix_shape()
        rhs = rhs + array_ops.zeros(batch_shape, dtype=rhs.dtype)

        # rhs has shape [B, R, C], where B represent some number of batch
        # dimensions,
        # R represents the number of rows, and C represents the number of columns.
        # In order to apply Roth's column lemma, we need to operate on a batch of
        # column vectors, so we reshape into a batch of column vectors. We put it
        # at the front to ensure that broadcasting between operators to the batch
        # dimensions B still works.
        output = _rotate_last_dim(rhs, rotate_right=True)

        # Also expand the shape to be [A, C, B, R]. The first dimension will be
        # used to accumulate dimensions from each operator matmul.
        output = output[_ops.newaxis, ...]

        # In this loop, A is going to refer to the value of the accumulated
        # dimension. A = 1 at the start, and will end up being self.range_dimension.
        # V will refer to the last dimension. V = R at the start, and will end up
        # being 1 in the end.
        for operator in self.operators[:-1]:
            # Reshape output from [A, C, B, V] to be
            # [A, C, B, V / op.domain_dimension, op.domain_dimension]
            if adjoint:
                operator_dimension = operator.range_dimension_tensor()
            else:
                operator_dimension = operator.domain_dimension_tensor()

            output = _unvec_by(output, operator_dimension)

            # We are computing X A^-T = (A^-1 X^T)^T.
            # output has [A, C, B, V / op.domain_dimension, op.domain_dimension],
            # which is being converted to:
            # [A, C, B, V / op.domain_dimension, op.range_dimension]
            output = _linalg.matrix_transpose(output)
            output = operator.solve(output, adjoint=adjoint, adjoint_arg=False)
            output = _linalg.matrix_transpose(output)
            # Rearrange it to [A * op.range_dimension, C, B, V / op.domain_dimension]
            output = _rotate_last_dim(output, rotate_right=False)
            output = _vec(output)
            output = _rotate_last_dim(output, rotate_right=True)

        # After the loop, we will have
        # A = self.range_dimension / op[-1].range_dimension
        # V = op[-1].domain_dimension

        # We convert that using matvec to get:
        # [A, C, B, op[-1].range_dimension]
        output = self.operators[-1].solvevec(output, adjoint=adjoint)
        # Rearrange shape to be [B1, ... Bn, self.range_dimension, C]
        output = _rotate_last_dim(output, rotate_right=False)
        output = _vec(output)
        output = _rotate_last_dim(output, rotate_right=False)

        if tensor_shape.TensorShape(rhs.shape).is_fully_defined():
            column_dim = tensor_shape.TensorShape(rhs.shape)[-1]
            broadcast_batch_shape = common_shapes.broadcast_shape(
                tensor_shape.TensorShape(rhs.shape)[:-2], self.batch_shape)
            if adjoint:
                matrix_dimensions = [self.domain_dimension, column_dim]
            else:
                matrix_dimensions = [self.range_dimension, column_dim]

            tensorshape_util.set_shape(
                output, broadcast_batch_shape.concatenate(matrix_dimensions))

        return output
    def _matmul(self, x, adjoint=False, adjoint_arg=False):
        # Here we heavily rely on Roth's column Lemma [1]:
        # (A x B) * vec X = vec BXA^T,
        # where vec stacks all the columns of the matrix under each other. In our
        # case, x represents a batch of vec X (i.e. we think of x as a batch of
        # column vectors, rather than a matrix). Each member of the batch can be
        # reshaped to a matrix (hence we get a batch of matrices).
        # We can iteratively apply this lemma by noting that if B is a Kronecker
        # product, then we can apply the lemma again.

        # [1] W. E. Roth, "On direct product matrices,"
        # Bulletin of the American Mathematical Society, vol. 40, pp. 461-468,
        # 1934

        # Efficiency

        # Naively doing the Kronecker product, by calculating the dense matrix
        # and applying it, can take cubic time in the size of domain_dimension
        # (assuming a square matrix). The other issue is that calculating the
        # dense matrix can be prohibitively expensive, in that it can take a
        # large amount of memory.
        #
        # This implementation avoids this memory blow up by only computing matmuls
        # with the factors. In this way, we don't have to realize the dense matrix.
        # In terms of complexity, if we have Kronecker Factors of size:
        # (n1, n1), (n2, n2), (n3, n3), ... (nJ, nJ), with N = \prod n_i, and we
        # have as input a [N, M] matrix, the naive approach would take O(N^2 M).
        # With this approach (ignoring reshaping of tensors and transposes for now),
        # the time complexity can be O(M * (\sum n_i) * N). There is also the
        # benefit of batched multiplication (in this example, the batch size is
        # roughly M * N), so this can be much faster. However, not factored in
        # are the costs of the several tensor transposes, which can affect cache
        # behavior.

        # Below we document the shape manipulation for adjoint=False,
        # adjoint_arg=False, but the general case of different adjoints is still
        # handled.

        if adjoint_arg:
            x = linalg.adjoint(x)

        # Always add a batch dimension to enable broadcasting to work.
        batch_shape = self._compute_ones_matrix_shape()
        x = x + array_ops.zeros(batch_shape, dtype=x.dtype)

        # x has shape [B, R, C], where B represent some number of batch dimensions,
        # R represents the number of rows, and C represents the number of columns.
        # In order to apply Roth's column lemma, we need to operate on a batch of
        # column vectors, so we reshape into a batch of column vectors. We put it
        # at the front to ensure that broadcasting between operators to the batch
        # dimensions B still works.
        output = _rotate_last_dim(x, rotate_right=True)

        # Also expand the shape to be [A, C, B, R]. The first dimension will be
        # used to accumulate dimensions from each operator matmul.
        output = output[_ops.newaxis, ...]

        # In this loop, A is going to refer to the value of the accumulated
        # dimension. A = 1 at the start, and will end up being self.range_dimension.
        # V will refer to the last dimension. V = R at the start, and will end up
        # being 1 in the end.
        for operator in self.operators[:-1]:
            # Reshape output from [A, C, B, V] to be
            # [A, C, B, V / op.domain_dimension, op.domain_dimension]
            if adjoint:
                operator_dimension = operator.range_dimension_tensor()
            else:
                operator_dimension = operator.domain_dimension_tensor()

            output = _unvec_by(output, operator_dimension)

            # We are computing X A^T = (A X^T)^T.
            # output has [A, C, B, V / op.domain_dimension, op.domain_dimension],
            # which is being converted to:
            # [A, C, B, V / op.domain_dimension, op.range_dimension]
            output = _linalg.matrix_transpose(output)
            output = operator.matmul(output,
                                     adjoint=adjoint,
                                     adjoint_arg=False)
            output = _linalg.matrix_transpose(output)
            # Rearrange it to [A * op.range_dimension, C, B, V / op.domain_dimension]
            output = _rotate_last_dim(output, rotate_right=False)
            output = _vec(output)
            output = _rotate_last_dim(output, rotate_right=True)

        # After the loop, we will have
        # A = self.range_dimension / op[-1].range_dimension
        # V = op[-1].domain_dimension

        # We convert that using matvec to get:
        # [A, C, B, op[-1].range_dimension]
        output = self.operators[-1].matvec(output, adjoint=adjoint)
        # Rearrange shape to be [B1, ... Bn, self.range_dimension, C]
        output = _rotate_last_dim(output, rotate_right=False)
        output = _vec(output)
        output = _rotate_last_dim(output, rotate_right=False)

        if tensor_shape.TensorShape(x.shape).is_fully_defined():
            column_dim = tensor_shape.TensorShape(x.shape)[-1]
            broadcast_batch_shape = common_shapes.broadcast_shape(
                tensor_shape.TensorShape(x.shape)[:-2], self.batch_shape)
            if adjoint:
                matrix_dimensions = [self.domain_dimension, column_dim]
            else:
                matrix_dimensions = [self.range_dimension, column_dim]

            tensorshape_util.set_shape(
                output, broadcast_batch_shape.concatenate(matrix_dimensions))

        return output
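Roth's column lemma, as used by `_matmul` and `_solve` above, can be verified directly in NumPy with `vec` taken column-major:

```python
import numpy as np

def vec(x):
    # Stack the columns of x under each other (column-major flatten).
    return x.reshape(-1, order='F')

A = np.random.randn(3, 4)
B = np.random.randn(5, 6)
X = np.random.randn(6, 4)  # shaped so that B @ X @ A.T is defined
# (A x B) vec(X) = vec(B X A^T)
np.testing.assert_allclose(
    np.kron(A, B) @ vec(X), vec(B @ X @ A.T), atol=1e-10)
```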
    def _shape(self):
        matrix_shape = tensor_shape.TensorShape(
            (self._num_rows_static, self._num_rows_static))

        batch_shape = tensor_shape.TensorShape(self.multiplier.shape)
        return batch_shape.concatenate(matrix_shape)
Example #18
    def _shape(self):
        batch_shape = _ops.broadcast_static_shape(
            self.base_operator.batch_shape,
            tensor_shape.TensorShape(self.u.shape)[:-2])
        return batch_shape.concatenate(
            tensor_shape.TensorShape(self.base_operator.shape)[-2:])
Example #19
    def solvevec(self, rhs, adjoint=False, name="solve"):
        """Solve single equation with best effort: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for details.

    Examples:

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve one linear system for every member of the batch.
    RHS = ... # shape [..., M]

    X = operator.solvevec(RHS)
    # X is the solution to the linear system
    # sum_j A[..., :, j] X[..., j] = RHS[..., :]

    operator.matvec(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator, or list of `Tensor`s
        (for blockwise operators). `Tensor`s are treated as [batch] vectors,
        meaning for every set of leading dimensions, the last dimension defines
        a vector.  See class docstring for definition of compatibility regarding
        batch dimensions.
      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
        of this `LinearOperator`:  `A^H X = rhs`.
      name:  A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[...,N]` and same `dtype` as `rhs`.

    Raises:
      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
    """
        with self._name_scope(name):
            block_dimensions = (self._block_domain_dimensions()
                                if adjoint else self._block_range_dimensions())
            if linear_operator_util.arg_is_blockwise(block_dimensions, rhs,
                                                     -1):
                for i, block in enumerate(rhs):
                    if not isinstance(block, linear_operator.LinearOperator):
                        block = ops.convert_to_tensor(block)
                        # self._check_input_dtype(block)
                        block_dimensions[i].assert_is_compatible_with(
                            tensor_shape.TensorShape(block.shape)[-1])
                        rhs[i] = block
                rhs_mat = [
                    array_ops.expand_dims(block, axis=-1) for block in rhs
                ]
                solution_mat = self.solve(rhs_mat, adjoint=adjoint)
                return [array_ops.squeeze(x, axis=-1) for x in solution_mat]

            rhs = ops.convert_to_tensor(rhs, name="rhs")
            # self._check_input_dtype(rhs)
            op_dimension = (self.domain_dimension
                            if adjoint else self.range_dimension)
            op_dimension.assert_is_compatible_with(
                tensor_shape.TensorShape(rhs.shape)[-1])
            rhs_mat = array_ops.expand_dims(rhs, axis=-1)
            solution_mat = self.solve(rhs_mat, adjoint=adjoint)
            return array_ops.squeeze(solution_mat, axis=-1)
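The reduction above, vector solve = matrix solve on a single appended column, looks like this in plain NumPy:

```python
import numpy as np

a = np.random.randn(4, 4)
rhs = np.random.randn(4)

rhs_mat = rhs[..., np.newaxis]          # [..., M] -> [..., M, 1]
solution_mat = np.linalg.solve(a, rhs_mat)
x = np.squeeze(solution_mat, axis=-1)   # [..., N, 1] -> [..., N]
np.testing.assert_allclose(a @ x, rhs, atol=1e-8)
```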
Example #20
    def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
        """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for details.

    Examples:

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve R > 0 linear systems for every member of the batch.
    RHS = ... # shape [..., M, R]

    X = operator.solve(RHS)
    # X[..., :, r] is the solution to the r'th linear system
    # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

    operator.matmul(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator and compatible shape,
        or a list of `Tensor`s (for blockwise operators). `Tensor`s are treated
        like [batch] matrices, meaning for every set of leading dimensions, the
        last two dimensions define a matrix.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
        of this `LinearOperator`:  `A^H X = rhs`.
      adjoint_arg:  Python `bool`.  If `True`, solve `A X = rhs^H` where `rhs^H`
        is the hermitian transpose (transposition and complex conjugation).
      name:  A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`.

    Raises:
      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
    """
        if self.is_non_singular is False:
            raise NotImplementedError(
                "Exact solve not implemented for an operator that is expected to "
                "be singular.")
        if self.is_square is False:
            raise NotImplementedError(
                "Exact solve not implemented for an operator that is expected to "
                "not be square.")
        if isinstance(rhs, linear_operator.LinearOperator):
            left_operator = self.adjoint() if adjoint else self
            right_operator = rhs.adjoint() if adjoint_arg else rhs

            if (right_operator.range_dimension is not None
                    and left_operator.domain_dimension is not None
                    and right_operator.range_dimension !=
                    left_operator.domain_dimension):
                raise ValueError(
                    "Operators are incompatible. Expected `rhs` to have dimension"
                    " {} but got {}.".format(left_operator.domain_dimension,
                                             right_operator.range_dimension))
            with self._name_scope(name):
                return linear_operator_algebra.solve(left_operator,
                                                     right_operator)

        with self._name_scope(name):
            block_dimensions = (self._block_domain_dimensions()
                                if adjoint else self._block_range_dimensions())
            arg_dim = -1 if adjoint_arg else -2
            blockwise_arg = linear_operator_util.arg_is_blockwise(
                block_dimensions, rhs, arg_dim)

            if blockwise_arg:
                split_rhs = rhs
                for i, block in enumerate(split_rhs):
                    if not isinstance(block, linear_operator.LinearOperator):
                        block = ops.convert_to_tensor(block)
                        # self._check_input_dtype(block)
                        block_dimensions[i].assert_is_compatible_with(
                            tensor_shape.TensorShape(block.shape)[arg_dim])
                        split_rhs[i] = block
            else:
                rhs = ops.convert_to_tensor(rhs, name="rhs")
                # self._check_input_dtype(rhs)
                op_dimension = (self.domain_dimension
                                if adjoint else self.range_dimension)
                op_dimension.assert_is_compatible_with(
                    tensor_shape.TensorShape(rhs.shape)[arg_dim])
                split_dim = -1 if adjoint_arg else -2
                # Split input by rows normally, and otherwise columns.
                split_rhs = linear_operator_util.split_arg_into_blocks(
                    self._block_domain_dimensions(),
                    self._block_domain_dimension_tensors,
                    rhs,
                    axis=split_dim)

            solution_list = []
            for index, operator in enumerate(self.operators):
                solution_list += [
                    operator.solve(split_rhs[index],
                                   adjoint=adjoint,
                                   adjoint_arg=adjoint_arg)
                ]

            if blockwise_arg:
                return solution_list

            solution_list = linear_operator_util.broadcast_matrix_batch_dims(
                solution_list)
            return array_ops.concat(solution_list, axis=-2)
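A NumPy sketch of the block-diagonal solve: split the right-hand side by the blocks' row dimensions, solve each block independently, and concatenate.

```python
import numpy as np

blocks = [np.random.randn(2, 2), np.random.randn(3, 3)]
rhs = np.random.randn(5, 4)              # [M, R] with M = 2 + 3

split_rhs = np.split(rhs, [2], axis=-2)  # row-blocks of sizes 2 and 3
solution = np.concatenate(
    [np.linalg.solve(b, r) for b, r in zip(blocks, split_rhs)], axis=-2)

# Check against the dense block-diagonal matrix.
dense = np.block([[blocks[0], np.zeros((2, 3))],
                  [np.zeros((3, 2)), blocks[1]]])
np.testing.assert_allclose(dense @ solution, rhs, atol=1e-8)
```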
Example #21
    def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
        """Transform [batch] matrix `x` with left multiplication:  `x --> Ax`.

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    X = ... # shape [..., N, R], batch matrix, R > 0.

    Y = operator.matmul(X)
    Y.shape
    ==> [..., M, R]

    Y[..., :, r] = sum_j A[..., :, j] X[..., j, r]
    ```

    Args:
      x: `LinearOperator`, `Tensor` with compatible shape and same `dtype` as
        `self`, or a blockwise iterable of `LinearOperator`s or `Tensor`s. See
        class docstring for definition of shape compatibility.
      adjoint: Python `bool`.  If `True`, left multiply by the adjoint: `A^H x`.
      adjoint_arg:  Python `bool`.  If `True`, compute `A x^H` where `x^H` is
        the hermitian transpose (transposition and complex conjugation).
      name:  A name for this `Op`.

    Returns:
      A `LinearOperator` or `Tensor` with shape `[..., M, R]` and same `dtype`
        as `self`, or if `x` is blockwise, a list of `Tensor`s with shapes that
        concatenate to `[..., M, R]`.
    """
        if isinstance(x, linear_operator.LinearOperator):
            left_operator = self.adjoint() if adjoint else self
            right_operator = x.adjoint() if adjoint_arg else x

            if (right_operator.range_dimension is not None
                    and left_operator.domain_dimension is not None
                    and right_operator.range_dimension !=
                    left_operator.domain_dimension):
                raise ValueError(
                    "Operators are incompatible. Expected `x` to have dimension"
                    " {} but got {}.".format(left_operator.domain_dimension,
                                             right_operator.range_dimension))
            with self._name_scope(name):
                return linear_operator_algebra.matmul(left_operator,
                                                      right_operator)

        with self._name_scope(name):
            arg_dim = -1 if adjoint_arg else -2
            block_dimensions = (self._block_range_dimensions() if adjoint else
                                self._block_domain_dimensions())
            if linear_operator_util.arg_is_blockwise(block_dimensions, x,
                                                     arg_dim):
                for i, block in enumerate(x):
                    if not isinstance(block, linear_operator.LinearOperator):
                        block = ops.convert_to_tensor(block)
                        # self._check_input_dtype(block)
                        block_dimensions[i].assert_is_compatible_with(
                            tensor_shape.TensorShape(block.shape)[arg_dim])
                        x[i] = block
            else:
                x = ops.convert_to_tensor(x, name="x")
                # self._check_input_dtype(x)
                op_dimension = (self.range_dimension
                                if adjoint else self.domain_dimension)
                op_dimension.assert_is_compatible_with(
                    tensor_shape.TensorShape(x.shape)[arg_dim])
            return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
    def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
        """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for details.

    Given the blockwise `n + 1`-by-`n + 1` linear operator:

    op = [[A_00     0  ...     0  ...    0],
          [A_10  A_11  ...     0  ...    0],
          ...
          [A_k0  A_k1  ...  A_kk  ...    0],
          ...
          [A_n0  A_n1  ...  A_nk  ... A_nn]]

    we find `x = op.solve(y)` by observing that

    `y_k = A_k0.matmul(x_0) + A_k1.matmul(x_1) + ... + A_kk.matmul(x_k)`

    and therefore

    `x_k = A_kk.solve(y_k -
                      A_k0.matmul(x_0) - ... - A_k(k-1).matmul(x_(k-1)))`

    where `x_k` and `y_k` are the `k`th blocks obtained by decomposing `x`
    and `y` along their appropriate axes.

    We first solve `x_0 = A_00.solve(y_0)`. Proceeding inductively, we solve
    for `x_k`, `k = 1..n`, given `x_0..x_(k-1)`.

    The adjoint case is solved similarly, beginning with
    `x_n = A_nn.solve(y_n, adjoint=True)` and proceeding backwards.

    Examples:

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve R > 0 linear systems for every member of the batch.
    RHS = ... # shape [..., M, R]

    X = operator.solve(RHS)
    # X[..., :, r] is the solution to the r'th linear system
    # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

    operator.matmul(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator and compatible shape,
        or a list of `Tensor`s. `Tensor`s are treated like [batch] matrices,
        meaning for every set of leading dimensions, the last two dimensions
        define a matrix.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
        of this `LinearOperator`:  `A^H X = rhs`.
      adjoint_arg:  Python `bool`.  If `True`, solve `A X = rhs^H` where `rhs^H`
        is the hermitian transpose (transposition and complex conjugation).
      name:  A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`.

    Raises:
      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
    """
        if self.is_non_singular is False:
            raise NotImplementedError(
                "Exact solve not implemented for an operator that is expected to "
                "be singular.")
        if self.is_square is False:
            raise NotImplementedError(
                "Exact solve not implemented for an operator that is expected to "
                "not be square.")
        if isinstance(rhs, linear_operator.LinearOperator):
            left_operator = self.adjoint() if adjoint else self
            right_operator = rhs.adjoint() if adjoint_arg else rhs

            if (right_operator.range_dimension is not None
                    and left_operator.domain_dimension is not None
                    and right_operator.range_dimension !=
                    left_operator.domain_dimension):
                raise ValueError(
                    "Operators are incompatible. Expected `rhs` to have dimension"
                    " {} but got {}.".format(left_operator.domain_dimension,
                                             right_operator.range_dimension))
            with self._name_scope(name):
                return linear_operator_algebra.solve(left_operator,
                                                     right_operator)

        with self._name_scope(name):
            block_dimensions = (self._block_domain_dimensions()
                                if adjoint else self._block_range_dimensions())
            arg_dim = -1 if adjoint_arg else -2
            blockwise_arg = linear_operator_util.arg_is_blockwise(
                block_dimensions, rhs, arg_dim)
            if blockwise_arg:
                for i, block in enumerate(rhs):
                    if not isinstance(block, linear_operator.LinearOperator):
                        block = ops.convert_to_tensor(block)
                        # self._check_input_dtype(block)
                        block_dimensions[i].assert_is_compatible_with(
                            tensor_shape.TensorShape(block.shape)[arg_dim])
                        rhs[i] = block
                if adjoint_arg:
                    split_rhs = [linalg.adjoint(y) for y in rhs]
                else:
                    split_rhs = rhs

            else:
                rhs = ops.convert_to_tensor(rhs, name="rhs")
                # self._check_input_dtype(rhs)
                op_dimension = (self.domain_dimension
                                if adjoint else self.range_dimension)
                op_dimension.assert_is_compatible_with(
                    tensor_shape.TensorShape(rhs.shape)[arg_dim])

                rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
                split_rhs = linear_operator_util.split_arg_into_blocks(
                    self._block_domain_dimensions(),
                    self._block_domain_dimension_tensors,
                    rhs,
                    axis=-2)

            solution_list = []
            if adjoint:
                # For an adjoint blockwise lower-triangular linear operator, the system
                # must be solved bottom to top. Iterate backwards over rows of the
                # adjoint (i.e. columns of the non-adjoint operator).
                for index in reversed(range(len(self.operators))):
                    y = split_rhs[index]
                    # Iterate top to bottom over the operators in the off-diagonal portion
                    # of the column-partition (i.e. row-partition of the adjoint), apply
                    # the operator to the respective block of the solution found in
                    # previous iterations, and subtract the result from the `rhs` block.
                    # For example, let `A`, `B`, and `D` be the linear operators in the top
                    # row-partition of the adjoint of
                    # `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])`,
                    # and `x_1` and `x_2` be blocks of the solution found in previous
                    # iterations of the outer loop. The following loop (when `index == 0`)
                    # expresses
                    # `Ax_0 + Bx_1 + Dx_2 = y_0` as `Ax_0 = y_0*`, where
                    # `y_0* = y_0 - Bx_1 - Dx_2`.
                    for j in reversed(range(index + 1, len(self.operators))):
                        y -= self.operators[j][index].matmul(
                            solution_list[len(self.operators) - 1 - j],
                            adjoint=adjoint)
                    # Continuing the example above, solve `Ax_0 = y_0*` for `x_0`.
                    solution_list.append(self._diagonal_operators[index].solve(
                        y, adjoint=adjoint))
                solution_list.reverse()
            else:
                # Iterate top to bottom over the row-partitions.
                for row, y in zip(self.operators, split_rhs):
                    # Iterate left to right over the operators in the off-diagonal portion
                    # of the row-partition, apply the operator to the block of the
                    # solution found in previous iterations, and subtract the result from
                    # the `rhs` block. For example, let `D`, `E`, and `F` be the linear
                    # operators in the bottom row-partition of
                    # `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])` and
                    # `x_0` and `x_1` be blocks of the solution found in previous
                    # iterations of the outer loop. The following loop
                    # (when `index == 2`) expresses
                    # `Dx_0 + Ex_1 + Fx_2 = y_2` as `Fx_2 = y_2*`, where
                    # `y_2* = y_2 - Dx_0 - Ex_1`.
                    for i, operator in enumerate(row[:-1]):
                        y -= operator.matmul(solution_list[i], adjoint=adjoint)
                    # Continuing the example above, solve `Fx_2 = y_2*` for `x_2`.
                    solution_list.append(row[-1].solve(y, adjoint=adjoint))

            if blockwise_arg:
                return solution_list

            solution_list = linear_operator_util.broadcast_matrix_batch_dims(
                solution_list)
            return array_ops.concat(solution_list, axis=-2)
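The forward substitution described in the docstring, written out in NumPy for a hypothetical 2x2 block lower triangular operator `[[A, 0], [B, C]]`:

```python
import numpy as np

A = np.random.randn(3, 3)
B = np.random.randn(2, 3)
C = np.random.randn(2, 2)
op = np.block([[A, np.zeros((3, 2))], [B, C]])

y = np.random.randn(5)
y0, y1 = y[:3], y[3:]
x0 = np.linalg.solve(A, y0)           # x_0 = A_00.solve(y_0)
x1 = np.linalg.solve(C, y1 - B @ x0)  # x_1 = A_11.solve(y_1 - A_10 x_0)
np.testing.assert_allclose(op @ np.concatenate([x0, x1]), y, atol=1e-8)
```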
Example #23
    def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
        """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for details.

    Examples:

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve R > 0 linear systems for every member of the batch.
    RHS = ... # shape [..., M, R]

    X = operator.solve(RHS)
    # X[..., :, r] is the solution to the r'th linear system
    # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

    operator.matmul(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator and compatible shape.
        `rhs` is treated like a [batch] matrix, meaning for every set of leading
        dimensions, the last two dimensions define a matrix.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
        of this `LinearOperator`:  `A^H X = rhs`.
      adjoint_arg:  Python `bool`.  If `True`, solve `A X = rhs^H` where `rhs^H`
        is the hermitian transpose (transposition and complex conjugation).
      name:  A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`.

    Raises:
      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
    """
        if self.is_non_singular is False:
            raise NotImplementedError(
                "Exact solve not implemented for an operator that is expected to "
                "be singular.")
        if self.is_square is False:
            raise NotImplementedError(
                "Exact solve not implemented for an operator that is expected to "
                "not be square.")
        if isinstance(rhs, LinearOperator):
            left_operator = self.adjoint() if adjoint else self
            right_operator = rhs.adjoint() if adjoint_arg else rhs

            if (right_operator.range_dimension is not None
                    and left_operator.domain_dimension is not None
                    and right_operator.range_dimension !=
                    left_operator.domain_dimension):
                raise ValueError(
                    "Operators are incompatible. Expected `rhs` to have dimension"
                    " {} but got {}.".format(left_operator.domain_dimension,
                                             right_operator.range_dimension))
            with self._name_scope(name):
                return linear_operator_algebra.solve(left_operator,
                                                     right_operator)

        with self._name_scope(name):
            rhs = ops.convert_to_tensor(rhs, name="rhs")
            # self._check_input_dtype(rhs)

            self_dim = -1 if adjoint else -2
            arg_dim = -1 if adjoint_arg else -2
            tensor_shape.dimension_at_index(
                tensor_shape.TensorShape(self.shape),
                self_dim).assert_is_compatible_with(
                    tensor_shape.TensorShape(rhs.shape)[arg_dim])

            return self._solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
Example #24
    def block_shape(self):
        return tensor_shape.TensorShape(
            self.spectrum.shape)[-self.block_depth:]
Example #25
    def _solve_matmul_internal(self,
                               x,
                               solve_matmul_fn,
                               adjoint=False,
                               adjoint_arg=False):
        # We heavily rely on Roth's column Lemma [1]:
        # (A x B) * vec X = vec BXA^T
        # where vec stacks all the columns of the matrix under each other.
        # In our case, we use a variant of the lemma that is row-major
        # friendly: (A x B) * vec' X = vec' AXB^T
        # Where vec' reshapes a matrix into a vector. We can repeatedly apply this
        # for a collection of kronecker products.
        # Given that (A x B)^-1 = A^-1 x B^-1 and (A x B)^T = A^T x B^T, we can
        # use the above to compute multiplications, solves with any composition of
        # transposes.
        output = x

        if adjoint_arg:
            if np.issubdtype(self.dtype, np.complexfloating):
                output = math_ops.conj(output)
        else:
            output = linalg.transpose(output)

        for o in reversed(self.operators):
            # Statically compute the reshape.
            if adjoint:
                operator_dimension = o.range_dimension_tensor()
            else:
                operator_dimension = o.domain_dimension_tensor()
            output_shape = _prefer_static_shape(output)

            if ops.get_static_value(operator_dimension) is not None:
                operator_dimension = ops.get_static_value(operator_dimension)
                if (tensor_shape.TensorShape(output.shape)[-2] is not None and
                        tensor_shape.TensorShape(output.shape)[-1] is not None):
                    dim = int(tensor_shape.TensorShape(output.shape)[-2] *
                              output_shape[-1] // operator_dimension)
            else:
                dim = _ops.cast(
                    output_shape[-2] * output_shape[-1] // operator_dimension,
                    dtype=dtypes.int32)

            output_shape = _prefer_static_concat_shape(
                output_shape[:-2], [dim, operator_dimension])
            output = array_ops.reshape(output, shape=output_shape)

            # Conjugate because we are trying to compute A @ B^T, but
            # `LinearOperator` only supports `adjoint_arg`.
            if np.issubdtype(self.dtype, np.complexfloating):
                output = math_ops.conj(output)

            output = solve_matmul_fn(o,
                                     output,
                                     adjoint=adjoint,
                                     adjoint_arg=True)

        if adjoint_arg:
            col_dim = _prefer_static_shape(x)[-2]
        else:
            col_dim = _prefer_static_shape(x)[-1]

        if adjoint:
            row_dim = self.domain_dimension_tensor()
        else:
            row_dim = self.range_dimension_tensor()

        matrix_shape = [row_dim, col_dim]

        output = array_ops.reshape(
            output,
            _prefer_static_concat_shape(
                _prefer_static_shape(output)[:-2], matrix_shape))

        if tensor_shape.TensorShape(x.shape).is_fully_defined():
            if adjoint_arg:
                column_dim = tensor_shape.TensorShape(x.shape)[-2]
            else:
                column_dim = tensor_shape.TensorShape(x.shape)[-1]
            broadcast_batch_shape = common_shapes.broadcast_shape(
                tensor_shape.TensorShape(x.shape)[:-2], self.batch_shape)
            if adjoint:
                matrix_dimensions = [self.domain_dimension, column_dim]
            else:
                matrix_dimensions = [self.range_dimension, column_dim]

            tensorshape_util.set_shape(
                output, broadcast_batch_shape.concatenate(matrix_dimensions))

        return output
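The row-major variant of Roth's lemma quoted in the comments, with `vec'` as a C-order flatten, also checks out in NumPy:

```python
import numpy as np

A = np.random.randn(3, 4)
B = np.random.randn(5, 6)
X = np.random.randn(4, 6)  # shaped so that A @ X @ B.T is defined
# (A x B) vec'(X) = vec'(A X B^T), with vec' a row-major flatten.
np.testing.assert_allclose(
    np.kron(A, B) @ X.reshape(-1), (A @ X @ B.T).reshape(-1), atol=1e-10)
```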
Example #26
def _prefer_static_shape(x):
    if tensor_shape.TensorShape(x.shape).is_fully_defined():
        return tensor_shape.TensorShape(x.shape)
    return prefer_static.shape(x)