Пример #1
0
    def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
        """Left-multiply the [batch] matrix `x` by this operator: `x --> Ax`.

        ```python
        # `operator` acts like a batch matrix A with A.shape = [..., M, N].
        operator = LinearOperator(...)

        X = ...  # shape [..., N, R], batch matrix, R > 0.

        Y = operator.matmul(X)
        Y.shape
        ==> [..., M, R]

        Y[..., :, r] = sum_j A[..., :, j] X[..., j, r]
        ```

        Args:
          x: `LinearOperator` or `Tensor` whose shape is compatible with
            `self` and whose `dtype` is expected to match it. See the class
            docstring for the definition of compatibility.
          adjoint: Python `bool`.  If `True`, multiply by the adjoint instead:
            `A^H x`.
          adjoint_arg: Python `bool`.  If `True`, compute `A x^H`, where `x^H`
            is the hermitian transpose (transposition and complex
            conjugation).
          name: A name for this `Op`.

        Returns:
          A `LinearOperator` or `Tensor` with shape `[..., M, R]` and the same
          `dtype` as `self`.

        Raises:
          ValueError: If `x` is a `LinearOperator` and the two dimensions
            being contracted are both not `None` and differ.
        """
        if isinstance(x, LinearOperator):
            # Operator-operator product: resolve the adjoint flags first so
            # the dimension check sees the operators that are actually
            # multiplied.
            lhs = self
            if adjoint:
                lhs = self.adjoint()
            rhs = x
            if adjoint_arg:
                rhs = x.adjoint()

            # Only reject the pair when both dimensions are available and
            # disagree; a `None` on either side skips the check.
            rhs_rows = rhs.range_dimension
            if rhs_rows is not None:
                lhs_cols = lhs.domain_dimension
                if lhs_cols is not None and rhs_rows != lhs_cols:
                    mismatch_template = (
                        "Operators are incompatible. Expected `x` to have dimension"
                        " {} but got {}.")
                    raise ValueError(
                        mismatch_template.format(lhs_cols, rhs_rows))

            with self._name_scope(name):
                return linear_operator_algebra.matmul(lhs, rhs)

        # Operator-tensor product.
        with self._name_scope(name):
            x_tensor = ops.convert_to_tensor(x, name="x")
            # NOTE: dtype validation (`self._check_input_dtype(x)`) is
            # commented out in this version, so only the shape is checked.

            # Pick the axis of `self` that is contracted against `x`, and the
            # matching axis of `x`, taking both adjoint flags into account.
            if adjoint:
                operator_axis = -2
            else:
                operator_axis = -1
            if adjoint_arg:
                arg_axis = -1
            else:
                arg_axis = -2

            operator_dim = tensor_shape.dimension_at_index(
                tensor_shape.TensorShape(self.shape), operator_axis)
            arg_dim_value = tensor_shape.TensorShape(x_tensor.shape)[arg_axis]
            operator_dim.assert_is_compatible_with(arg_dim_value)

            return self._matmul(
                x_tensor, adjoint=adjoint, adjoint_arg=adjoint_arg)
Пример #2
0
    def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
        """Transform [batch] matrix `x` with left multiplication: `x --> Ax`.

        ```python
        # `operator` acts like a batch matrix A with A.shape = [..., M, N].
        operator = LinearOperator(...)

        X = ...  # shape [..., N, R], batch matrix, R > 0.

        Y = operator.matmul(X)
        Y.shape
        ==> [..., M, R]

        Y[..., :, r] = sum_j A[..., :, j] X[..., j, r]
        ```

        Args:
          x: `LinearOperator`, `Tensor` with compatible shape and same `dtype`
            as `self`, or a blockwise iterable (e.g. `list` or `tuple`) of
            `LinearOperator`s or `Tensor`s. See class docstring for definition
            of shape compatibility. A blockwise `x` is never modified; its
            blocks are converted into a new list.
          adjoint: Python `bool`.  If `True`, left multiply by the adjoint:
            `A^H x`.
          adjoint_arg: Python `bool`.  If `True`, compute `A x^H` where `x^H`
            is the hermitian transpose (transposition and complex
            conjugation).
          name: A name for this `Op`.

        Returns:
          A `LinearOperator` or `Tensor` with shape `[..., M, R]` and same
          `dtype` as `self`, or if `x` is blockwise, a list of `Tensor`s with
          shapes that concatenate to `[..., M, R]`.

        Raises:
          ValueError: If `x` is a `LinearOperator` and the two dimensions
            being contracted are both not `None` and differ.
        """
        if isinstance(x, linear_operator.LinearOperator):
            # Operator-operator product: resolve the adjoint flags first so
            # the dimension check sees the operators actually multiplied.
            left_operator = self.adjoint() if adjoint else self
            right_operator = x.adjoint() if adjoint_arg else x

            if (right_operator.range_dimension is not None
                    and left_operator.domain_dimension is not None
                    and right_operator.range_dimension !=
                    left_operator.domain_dimension):
                raise ValueError(
                    "Operators are incompatible. Expected `x` to have dimension"
                    " {} but got {}.".format(left_operator.domain_dimension,
                                             right_operator.range_dimension))
            with self._name_scope(name):
                return linear_operator_algebra.matmul(left_operator,
                                                      right_operator)

        with self._name_scope(name):
            # Axis of `x` (or of each block of `x`) that is contracted against
            # this operator.
            arg_dim = -1 if adjoint_arg else -2
            block_dimensions = (self._block_range_dimensions() if adjoint else
                                self._block_domain_dimensions())
            if linear_operator_util.arg_is_blockwise(block_dimensions, x,
                                                     arg_dim):
                # Convert into a private copy: assigning into `x` itself would
                # overwrite the caller's list and would fail outright for
                # tuples, which do not support item assignment.
                blocks = list(x)
                for i, block in enumerate(blocks):
                    if not isinstance(block, linear_operator.LinearOperator):
                        block = ops.convert_to_tensor(block)
                        # NOTE: dtype validation
                        # (`self._check_input_dtype(block)`) is commented out
                        # in this version; only the shape is checked.
                        block_dimensions[i].assert_is_compatible_with(
                            tensor_shape.TensorShape(block.shape)[arg_dim])
                        blocks[i] = block
                x = blocks
            else:
                x = ops.convert_to_tensor(x, name="x")
                # NOTE: dtype validation (`self._check_input_dtype(x)`) is
                # commented out in this version; only the shape is checked.
                op_dimension = (self.range_dimension
                                if adjoint else self.domain_dimension)
                op_dimension.assert_is_compatible_with(
                    tensor_shape.TensorShape(x.shape)[arg_dim])
            return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)