Beispiel #1
0
  def _shape(self):
    """Returns the static shape `batch_shape + [range_dim, domain_dim]`.

    The block-diagonal operator's range (row) dimension is the sum of the
    blocks' range dimensions and its domain (column) dimension is the sum of
    the blocks' domain dimensions. The blocks' batch shapes are broadcast
    together (`common_shapes.broadcast_shape` checks for compatibility).
    """
    # Get final matrix shape.
    domain_dimension = sum(self._block_domain_dimensions())
    range_dimension = sum(self._block_range_dimensions())
    # LinearOperator shapes are [..., M, N] with M = range_dimension (rows)
    # and N = domain_dimension (columns). Listing them the other way round is
    # only harmless when every block happens to be square.
    matrix_shape = tensor_shape.TensorShape([range_dimension, domain_dimension])

    # Get broadcast batch shape.
    # broadcast_shape checks for compatibility.
    batch_shape = self.operators[0].batch_shape
    for operator in self.operators[1:]:
      batch_shape = common_shapes.broadcast_shape(
          batch_shape, operator.batch_shape)

    return batch_shape.concatenate(matrix_shape)
  def _shape(self):
    """Static shape of the composed operator: broadcast batch + [rows, cols].

    Adjacent operators must compose, so each operator's domain dimension is
    checked against the following operator's range dimension. The result has
    the first operator's range dimension as rows and the last operator's
    domain dimension as columns. Batch shapes are broadcast together
    (`common_shapes.broadcast_shape` checks for compatibility).
    """
    for left, right in zip(self.operators[:-1], self.operators[1:]):
      left.domain_dimension.assert_is_compatible_with(right.range_dimension)

    rows = self.operators[0].range_dimension
    cols = self.operators[-1].domain_dimension
    matrix_shape = tensor_shape.TensorShape([rows, cols])

    batch_shape = self.operators[0].batch_shape
    for other in self.operators[1:]:
      batch_shape = common_shapes.broadcast_shape(batch_shape,
                                                  other.batch_shape)
    return batch_shape.concatenate(matrix_shape)
    def _solve(self, rhs, adjoint=False, adjoint_arg=False):
        """Solves `self @ X = rhs` (or its adjoint) one Kronecker factor at a time.

        Follows the same use of Roth's column lemma as `_matmul`, replacing
        each factor `matmul` with `solve`. This is valid because
        inv(A x B) = inv(A) x inv(B).

        Args:
          rhs: Right-hand side, documented below as shape [B, R, C] (B batch
            dims, R rows, C columns); adjointed first when `adjoint_arg=True`.
          adjoint: If True, solve against the adjoint of this operator.
          adjoint_arg: If True, use the adjoint of `rhs`.

        Returns:
          The solution. When `rhs` has a fully defined static shape, the
          result's static shape is set to
          [broadcast batch, self.domain_dimension, C], or
          [broadcast batch, self.range_dimension, C] when `adjoint=True`.
        """
        # Below we document the shape manipulation for adjoint=False,
        # adjoint_arg=False, but the general case of different adjoints is
        # still handled.

        if adjoint_arg:
            rhs = linalg.adjoint(rhs)

        # Always add a batch dimension to enable broadcasting to work.
        batch_shape = self._compute_ones_matrix_shape()
        rhs = rhs + array_ops.zeros(batch_shape, dtype=rhs.dtype)

        # rhs has shape [B, R, C]. Roth's column lemma operates on a batch of
        # column vectors, so rotate to [C, B, R]. Putting C in front keeps
        # broadcasting between the operators and the batch dimensions B working.
        output = _rotate_last_dim(rhs, rotate_right=True)

        # Expand to [A, C, B, R]. The leading axis accumulates the dimensions
        # produced by each factor solve.
        output = output[_ops.newaxis, ...]

        # In this loop, A is the accumulated dimension (A = 1 at the start) and
        # V is the last dimension (V = R at the start).
        for operator in self.operators[:-1]:
            # `operator.solve` consumes as many rows as the factor's range
            # dimension (its domain dimension for the adjoint). It produces as
            # many rows as its domain dimension (range for the adjoint). This
            # is the reverse of matmul's contract; the two coincide only for
            # square factors.
            if adjoint:
                operator_dimension = operator.domain_dimension_tensor()
            else:
                operator_dimension = operator.range_dimension_tensor()

            # Reshape output from [A, C, B, V] to
            # [A, C, B, V / op.range_dimension, op.range_dimension].
            output = _unvec_by(output, operator_dimension)

            # We are computing X A^-T = (A^-1 X^T)^T, which yields
            # [A, C, B, V / op.range_dimension, op.domain_dimension].
            output = _linalg.matrix_transpose(output)
            output = operator.solve(output, adjoint=adjoint, adjoint_arg=False)
            output = _linalg.matrix_transpose(output)
            # Rearrange to [A * op.domain_dimension, C, B, V / op.range_dimension].
            output = _rotate_last_dim(output, rotate_right=False)
            output = _vec(output)
            output = _rotate_last_dim(output, rotate_right=True)

        # After the loop, A = self.domain_dimension / op[-1].domain_dimension
        # and V = op[-1].range_dimension. solvevec maps the last axis to get
        # [A, C, B, op[-1].domain_dimension].
        output = self.operators[-1].solvevec(output, adjoint=adjoint)
        # Rearrange shape to be [B1, ... Bn, self.domain_dimension, C].
        output = _rotate_last_dim(output, rotate_right=False)
        output = _vec(output)
        output = _rotate_last_dim(output, rotate_right=False)

        rhs_shape = tensor_shape.TensorShape(rhs.shape)
        if rhs_shape.is_fully_defined():
            column_dim = rhs_shape[-1]
            broadcast_batch_shape = common_shapes.broadcast_shape(
                rhs_shape[:-2], self.batch_shape)
            # The solution X of `A X = rhs` has as many rows as A has columns
            # (domain); for the adjoint system it has range_dimension rows.
            if adjoint:
                matrix_dimensions = [self.range_dimension, column_dim]
            else:
                matrix_dimensions = [self.domain_dimension, column_dim]

            tensorshape_util.set_shape(
                output, broadcast_batch_shape.concatenate(matrix_dimensions))

        return output
    def _matmul(self, x, adjoint=False, adjoint_arg=False):
        """Multiplies this Kronecker operator by `x` without forming it densely.

        Each Kronecker factor is applied in turn through reshapes and
        transposes (Roth's column lemma, described below).

        Args:
          x: Input documented below as shape [B, R, C] (B batch dims, R rows,
            C columns); adjointed first when `adjoint_arg=True`.
          adjoint: If True, multiply by the adjoint of this operator.
          adjoint_arg: If True, multiply by the adjoint of `x`.

        Returns:
          The product. When `x` has a fully defined static shape, the result's
          static shape is set to [broadcast batch, self.range_dimension, C], or
          [broadcast batch, self.domain_dimension, C] when `adjoint=True`.
        """
        # Here we heavily rely on Roth's column Lemma [1]:
        # (A x B) * vec X = vec BXA^T,
        # where vec stacks all the columns of the matrix under each other. In our
        # case, x represents a batch of vec X (i.e. we think of x as a batch of
        # column vectors, rather than a matrix). Each member of the batch can be
        # reshaped to a matrix (hence we get a batch of matrices).
        # We can iteratively apply this lemma by noting that if B is a Kronecker
        # product, then we can apply the lemma again.

        # [1] W. E. Roth, "On direct product matrices,"
        # Bulletin of the American Mathematical Society, vol. 40, pp. 461-468,
        # 1934

        # Efficiency

        # Naively doing the Kronecker product, by calculating the dense matrix and
        # applying it, can take cubic time in the size of domain_dimension
        # (assuming a square matrix). The other issue is that calculating the dense
        # matrix can be prohibitively expensive, in that it can take a large amount
        # of memory.
        #
        # This implementation avoids this memory blow up by only computing matmuls
        # with the factors. In this way, we don't have to realize the dense matrix.
        # In terms of complexity, if we have Kronecker Factors of size:
        # (n1, n1), (n2, n2), (n3, n3), ... (nJ, nJ), with N = \prod n_i, and we
        # have as input a [N, M] matrix, the naive approach would take O(N^2 M).
        # With this approach (ignoring reshaping of tensors and transposes for now),
        # the time complexity can be O(M * (\sum n_i) * N). There is also the
        # benefit of batched multiplication (In this example, the batch size is
        # roughly M * N) so this can be much faster. However, not factored in are
        # the costs of the several transposing of tensors, which can affect cache
        # behavior.

        # Below we document the shape manipulation for adjoint=False,
        # adjoint_arg=False, but the general case of different adjoints is still
        # handled.

        if adjoint_arg:
            x = linalg.adjoint(x)

        # Always add a batch dimension to enable broadcasting to work.
        batch_shape = self._compute_ones_matrix_shape()
        x = x + array_ops.zeros(batch_shape, dtype=x.dtype)

        # x has shape [B, R, C], where B represent some number of batch dimensions,
        # R represents the number of rows, and C represents the number of columns.
        # In order to apply Roth's column lemma, we need to operate on a batch of
        # column vectors, so we reshape into a batch of column vectors. We put it
        # at the front to ensure that broadcasting between operators to the batch
        # dimensions B still works.
        output = _rotate_last_dim(x, rotate_right=True)

        # Also expand the shape to be [A, C, B, R]. The first dimension will be
        # used to accumulate dimensions from each operator matmul.
        output = output[_ops.newaxis, ...]

        # In this loop, A is going to refer to the value of the accumulated
        # dimension. A = 1 at the start, and will end up being self.range_dimension.
        # V will refer to the last dimension. V = R at the start, and will end up
        # being 1 in the end.
        for operator in self.operators[:-1]:
            # Reshape output from [A, C, B, V] to be
            # [A, C, B, V / op.domain_dimension, op.domain_dimension]
            # (op.range_dimension takes the place of op.domain_dimension
            # when adjoint=True, since the adjoint's columns are op's rows).
            if adjoint:
                operator_dimension = operator.range_dimension_tensor()
            else:
                operator_dimension = operator.domain_dimension_tensor()

            output = _unvec_by(output, operator_dimension)

            # We are computing (XA^T) = (AX^T)^T.
            # output has [A, C, B, V / op.domain_dimension, op.domain_dimension],
            # which is being converted to:
            # [A, C, B, V / op.domain_dimension, op.range_dimension]
            output = _linalg.matrix_transpose(output)
            output = operator.matmul(output,
                                     adjoint=adjoint,
                                     adjoint_arg=False)
            output = _linalg.matrix_transpose(output)
            # Rearrange it to [A * op.range_dimension, C, B, V / op.domain_dimension]
            output = _rotate_last_dim(output, rotate_right=False)
            output = _vec(output)
            output = _rotate_last_dim(output, rotate_right=True)

        # After the loop, we will have
        # A = self.range_dimension / op[-1].range_dimension
        # V = op[-1].domain_dimension

        # We convert that using matvec to get:
        # [A, C, B, op[-1].range_dimension]
        output = self.operators[-1].matvec(output, adjoint=adjoint)
        # Rearrange shape to be [B1, ... Bn, self.range_dimension, C]
        output = _rotate_last_dim(output, rotate_right=False)
        output = _vec(output)
        output = _rotate_last_dim(output, rotate_right=False)

        # Attach the static output shape when it can be inferred from x; the
        # product has range_dimension rows (domain_dimension for the adjoint).
        if tensor_shape.TensorShape(x.shape).is_fully_defined():
            column_dim = tensor_shape.TensorShape(x.shape)[-1]
            broadcast_batch_shape = common_shapes.broadcast_shape(
                tensor_shape.TensorShape(x.shape)[:-2], self.batch_shape)
            if adjoint:
                matrix_dimensions = [self.domain_dimension, column_dim]
            else:
                matrix_dimensions = [self.range_dimension, column_dim]

            tensorshape_util.set_shape(
                output, broadcast_batch_shape.concatenate(matrix_dimensions))

        return output
def _bcast_shape(base_shape, args):
    """Broadcast `base_shape` against the shape of every element of `args`.

    `base_shape` is normalized with `_ensure_shape_tuple`. Each argument is
    converted with `np.asarray` to read its shape, and the running result is
    combined with that shape via `ops.broadcast_shape`, in argument order.
    """
    shape = _ensure_shape_tuple(base_shape)
    arg_shapes = (np.asarray(arg).shape for arg in args)
    for arg_shape in arg_shapes:
        shape = ops.broadcast_shape(shape, arg_shape)
    return shape
Beispiel #6
0
    def _solve_matmul_internal(self,
                               x,
                               solve_matmul_fn,
                               adjoint=False,
                               adjoint_arg=False):
        """Applies `solve_matmul_fn` factor by factor to emulate the Kronecker op.

        Args:
          x: Right-hand-side matrix (or batch of matrices).
          solve_matmul_fn: Callable invoked as
            `solve_matmul_fn(operator, rhs, adjoint=..., adjoint_arg=True)` on
            each Kronecker factor (e.g. a matmul or a solve).
          adjoint: If True, use the adjoint of this operator.
          adjoint_arg: If True, use the adjoint of `x`.

        Returns:
          The result, reshaped so its last two dims are [row_dim, col_dim].
          Here row_dim is `self.range_dimension` (`self.domain_dimension` when
          `adjoint=True`) and col_dim is the last dim of `x` (second-to-last
          when `adjoint_arg=True`).
        """
        # We heavily rely on Roth's column Lemma [1]:
        # (A x B) * vec X = vec BXA^T
        # where vec stacks all the columns of the matrix under each other.
        # In our case, we use a variant of the lemma that is row-major
        # friendly: (A x B) * vec' X = vec' AXB^T
        # Where vec' reshapes a matrix into a vector. We can repeatedly apply this
        # for a collection of kronecker products.
        # Given that (A x B)^-1 = A^-1 x B^-1 and (A x B)^T = A^T x B^T, we can
        # use the above to compute multiplications, solves with any composition of
        # transposes.
        # The dtype is loop-invariant, so check for complex values once.
        is_complex = np.issubdtype(self.dtype, np.complexfloating)
        output = x

        if adjoint_arg:
            if is_complex:
                output = math_ops.conj(output)
        else:
            output = linalg.transpose(output)

        for o in reversed(self.operators):
            # Statically compute the reshape when possible.
            if adjoint:
                operator_dimension = o.range_dimension_tensor()
            else:
                operator_dimension = o.domain_dimension_tensor()
            output_shape = _prefer_static_shape(output)

            static_operator_dimension = ops.get_static_value(operator_dimension)
            if static_operator_dimension is not None:
                operator_dimension = static_operator_dimension
            static_rows = tensor_shape.TensorShape(output.shape)[-2]
            static_cols = tensor_shape.TensorShape(output.shape)[-1]

            if (static_operator_dimension is not None and
                    static_rows is not None and static_cols is not None):
                dim = int(static_rows * static_cols // static_operator_dimension)
            else:
                # Either the factor dimension or the trailing dims of `output`
                # are only known at run time. Compute `dim` dynamically here, so
                # it is never left unassigned or reused from a previous factor.
                dim = _ops.cast(output_shape[-2] * output_shape[-1] //
                                operator_dimension,
                                dtype=dtypes.int32)

            output_shape = _prefer_static_concat_shape(
                output_shape[:-2], [dim, operator_dimension])
            output = array_ops.reshape(output, shape=output_shape)

            # Conjugate because we are trying to compute A @ B^T, but
            # `LinearOperator` only supports `adjoint_arg`.
            if is_complex:
                output = math_ops.conj(output)

            output = solve_matmul_fn(o,
                                     output,
                                     adjoint=adjoint,
                                     adjoint_arg=True)

        if adjoint_arg:
            col_dim = _prefer_static_shape(x)[-2]
        else:
            col_dim = _prefer_static_shape(x)[-1]

        if adjoint:
            row_dim = self.domain_dimension_tensor()
        else:
            row_dim = self.range_dimension_tensor()

        matrix_shape = [row_dim, col_dim]

        output = array_ops.reshape(
            output,
            _prefer_static_concat_shape(
                _prefer_static_shape(output)[:-2], matrix_shape))

        # Attach the static output shape when it can be inferred from x.
        if tensor_shape.TensorShape(x.shape).is_fully_defined():
            if adjoint_arg:
                column_dim = tensor_shape.TensorShape(x.shape)[-2]
            else:
                column_dim = tensor_shape.TensorShape(x.shape)[-1]
            broadcast_batch_shape = common_shapes.broadcast_shape(
                tensor_shape.TensorShape(x.shape)[:-2], self.batch_shape)
            if adjoint:
                matrix_dimensions = [self.domain_dimension, column_dim]
            else:
                matrix_dimensions = [self.range_dimension, column_dim]

            tensorshape_util.set_shape(
                output, broadcast_batch_shape.concatenate(matrix_dimensions))

        return output