Example #1
    def _to_dense(self):
        num_cols = 0
        rows = []
        broadcasted_blocks = [
            operator.to_dense() for operator in self.operators
        ]
        broadcasted_blocks = linear_operator_util.broadcast_matrix_batch_dims(
            broadcasted_blocks)
        for block in broadcasted_blocks:
            batch_row_shape = array_ops.shape(block)[:-1]

            zeros_to_pad_before_shape = array_ops.concat(
                [batch_row_shape, [num_cols]], axis=-1)
            zeros_to_pad_before = array_ops.zeros(
                shape=zeros_to_pad_before_shape, dtype=block.dtype)
            num_cols += array_ops.shape(block)[-1]
            zeros_to_pad_after_shape = array_ops.concat(
                [batch_row_shape, [self.domain_dimension_tensor() - num_cols]],
                axis=-1)
            zeros_to_pad_after = array_ops.zeros(
                shape=zeros_to_pad_after_shape, dtype=block.dtype)

            rows.append(
                array_ops.concat(
                    [zeros_to_pad_before, block, zeros_to_pad_after], axis=-1))

        mat = array_ops.concat(rows, axis=-2)
        tensorshape_util.set_shape(mat, tensor_shape.TensorShape(self.shape))
        return mat
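
The `_to_dense` above turns a list of blocks into one dense block-diagonal matrix by padding each block with zeros to its left and right and then stacking the padded rows. A minimal NumPy sketch of the same padding scheme, ignoring batch dimensions (the `block_diag_dense` helper and the example blocks are illustrative, not part of the library):

import numpy as np

def block_diag_dense(blocks):
    # Pad every block with zeros on the left and right so each padded row
    # spans all columns, then stack the rows.
    total_cols = sum(b.shape[-1] for b in blocks)
    num_cols = 0
    rows = []
    for block in blocks:
        pad_before = np.zeros((block.shape[0], num_cols), dtype=block.dtype)
        num_cols += block.shape[-1]
        pad_after = np.zeros((block.shape[0], total_cols - num_cols),
                             dtype=block.dtype)
        rows.append(np.concatenate([pad_before, block, pad_after], axis=-1))
    return np.concatenate(rows, axis=-2)

# Two blocks of sizes 2x2 and 3x3 give a 5x5 block-diagonal matrix.
a = np.arange(4.).reshape(2, 2)
b = np.arange(9.).reshape(3, 3)
print(block_diag_dense([a, b]).shape)  # (5, 5)
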
    def _possibly_broadcast_batch_shape(self, x):
        """Return 'x', possibly after broadcasting the leading dimensions."""
        # If we have no batch shape, our batch shape broadcasts with everything!
        if self._batch_shape_arg is None:
            return x

        # Static attempt:
        #   If we determine that no broadcast is necessary, pass x through
        #   If we need a broadcast, add to an array of zeros.
        #
        # special_shape is the shape that, when broadcast with x's shape, will give
        # the correct broadcast_shape.  Note that
        #   We have already verified that the second-to-last dimension of self.shape
        #   matches x's shape in assert_compatible_matrix_dimensions.
        #   Also, the final dimension of 'x' can have any shape.
        #   Therefore, the final two dimensions of special_shape are 1's.
        special_shape = self.batch_shape.concatenate([1, 1])
        bshape = _ops.broadcast_static_shape(tensor_shape.TensorShape(x.shape),
                                             special_shape)
        if special_shape.is_fully_defined():
            # bshape.is_fully_defined iff special_shape.is_fully_defined.
            if bshape == tensor_shape.TensorShape(x.shape):
                return x
            # Use the built in broadcasting of addition.
            zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
            return x + zeros

        # Dynamic broadcast:
        #   Always add to an array of zeros, rather than using a "cond", since a
        #   cond would require copying data from GPU --> CPU.
        special_shape = array_ops.concat((self.batch_shape_tensor(), [1, 1]),
                                         0)
        zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
        return x + zeros
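
Both branches above rely on the same trick: adding a zeros tensor of shape `batch_shape + [1, 1]` lets ordinary `+` broadcasting expand `x` to the operator's batch shape, with no data-dependent control flow. A small NumPy sketch, assuming a made-up operator batch shape of `[3]`:

import numpy as np

x = np.ones((2, 2))              # input with no batch dimensions
special_shape = (3, 1, 1)        # operator batch shape + [1, 1]
zeros = np.zeros(special_shape)
broadcasted = x + zeros          # "+" broadcasts x up to the batch shape
print(broadcasted.shape)         # (3, 2, 2)
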
Example #3
 def _trace(self):
     # Get Tensor of all zeros of same shape as self.batch_shape.
     if self.batch_shape.is_fully_defined():
         return array_ops.zeros(shape=self.batch_shape, dtype=self.dtype)
     else:
         return array_ops.zeros(shape=self.batch_shape_tensor(),
                                dtype=self.dtype)
Example #4
    def _matmul(self, x, adjoint=False, adjoint_arg=False):
        if self._assert_proper_shapes:
            x = linalg.adjoint(x) if adjoint_arg else x
            aps = linear_operator_util.assert_compatible_matrix_dimensions(
                self, x)
            x = distribution_util.with_dependencies([aps], x)
        if self.is_square:
            # Note that adjoint has no effect since this matrix is self-adjoint.
            if adjoint_arg:
                output_shape = prefer_static.concat([
                    prefer_static.shape(x)[:-2],
                    [prefer_static.shape(x)[-1], prefer_static.shape(x)[-2]]
                ], axis=0)
            else:
                output_shape = prefer_static.shape(x)

            return self._possibly_broadcast_batch_shape(
                array_ops.zeros(shape=output_shape, dtype=x.dtype))

        x_shape = prefer_static.shape(x)
        n = self._num_columns if adjoint else self._num_rows
        m = x_shape[-2] if adjoint_arg else x_shape[-1]

        output_shape = prefer_static.concat([x_shape[:-2], [n, m]], axis=0)

        zeros = array_ops.zeros(shape=output_shape, dtype=x.dtype)
        return self._possibly_broadcast_batch_shape(zeros)
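
In the non-square branch the result is simply a zeros tensor whose trailing dimensions follow the usual matmul shape rule: `n` rows from the (possibly adjointed) zero operator and `m` columns from the (possibly adjointed) argument. A sketch of that rule with made-up sizes (`num_rows=4`, `num_columns=3` are assumptions for illustration):

import numpy as np

num_rows, num_columns = 4, 3
x = np.random.rand(7, 3, 5)      # batch shape [7], matrix shape [3, 5]
adjoint, adjoint_arg = False, False

n = num_columns if adjoint else num_rows           # rows of the operator
m = x.shape[-2] if adjoint_arg else x.shape[-1]    # columns of the argument
result = np.zeros(x.shape[:-2] + (n, m), dtype=x.dtype)
print(result.shape)  # (7, 4, 5)
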
Example #5
  def _shape_tensor(self):
    # Avoid messy broadcasting if possible.
    if tensor_shape.TensorShape(self.shape).is_fully_defined():
      return ops.convert_to_tensor(
          tensor_shape.TensorShape(self.shape).as_list(), dtype=dtypes.int32, name="shape")

    domain_dimension = sum(self._block_domain_dimension_tensors())
    range_dimension = sum(self._block_range_dimension_tensors())
    matrix_shape = array_ops.stack([domain_dimension, range_dimension])

    # Dummy Tensor of zeros.  Will never be materialized.
    zeros = array_ops.zeros(shape=self.operators[0].batch_shape_tensor())
    for operator in self.operators[1:]:
      zeros = zeros + array_ops.zeros(shape=operator.batch_shape_tensor())
    batch_shape = array_ops.shape(zeros)

    return array_ops.concat((batch_shape, matrix_shape), 0)
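
The `zeros` accumulation is a cheap way to combine several dynamic batch shapes: adding zero tensors of each batch shape produces a tensor whose shape is their broadcast, and only its shape is read afterwards. A NumPy sketch with made-up batch shapes:

import numpy as np

batch_shapes = [(1, 3), (5, 1), (1, 1)]
zeros = np.zeros(batch_shapes[0])
for shape in batch_shapes[1:]:
    zeros = zeros + np.zeros(shape)
print(zeros.shape)  # (5, 3), the broadcast of all the batch shapes
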
Example #6
    def _zeros_diag(self):
        """Returns the diagonal of this operator as all zeros."""
        if tensor_shape.TensorShape(self.shape).is_fully_defined():
            d_shape = self.batch_shape.concatenate([self._min_matrix_dim()])
        else:
            d_shape = prefer_static.concat(
                [self.batch_shape_tensor(), [self._min_matrix_dim_tensor()]],
                axis=0)

        return array_ops.zeros(shape=d_shape, dtype=self.dtype)
  def _shape_tensor(self):
    # Avoid messy broadcasting if possible.
    if tensor_shape.TensorShape(self.shape).is_fully_defined():
      return ops.convert_to_tensor(
          tensor_shape.TensorShape(self.shape).as_list(), dtype=dtypes.int32, name="shape")

    # Don't check the matrix dimensions.  That would add unnecessary Asserts to
    # the graph.  Things will fail at runtime naturally if shapes are
    # incompatible.
    matrix_shape = array_ops.stack([
        self.operators[0].range_dimension_tensor(),
        self.operators[-1].domain_dimension_tensor()
    ])

    # Dummy Tensor of zeros.  Will never be materialized.
    zeros = array_ops.zeros(shape=self.operators[0].batch_shape_tensor())
    for operator in self.operators[1:]:
      zeros += array_ops.zeros(shape=operator.batch_shape_tensor())
    batch_shape = array_ops.shape(zeros)

    return array_ops.concat((batch_shape, matrix_shape), 0)
  def _to_dense(self):
    num_cols = 0
    dense_rows = []
    flat_broadcast_operators = linear_operator_util.broadcast_matrix_batch_dims(
        [op.to_dense() for row in self.operators for op in row])  # pylint: disable=g-complex-comprehension
    broadcast_operators = [
        flat_broadcast_operators[i * (i + 1) // 2:(i + 1) * (i + 2) // 2]
        for i in range(len(self.operators))]
    for row_blocks in broadcast_operators:
      batch_row_shape = prefer_static.shape(row_blocks[0])[:-1]
      num_cols = num_cols + prefer_static.shape(row_blocks[-1])[-1]
      zeros_to_pad_after_shape = prefer_static.concat(
          [batch_row_shape,
           [self.domain_dimension_tensor() - num_cols]], axis=-1)
      zeros_to_pad_after = array_ops.zeros(
          shape=zeros_to_pad_after_shape, dtype=self.dtype)

      row_blocks.append(zeros_to_pad_after)
      dense_rows.append(prefer_static.concat(row_blocks, axis=-1))

    mat = prefer_static.concat(dense_rows, axis=-2)
    tensorshape_util.set_shape(mat, tensor_shape.TensorShape(self.shape))
    return mat
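
`flat_broadcast_operators` holds the blocks of the lower-triangular structure row by row, so row `i` contains `i + 1` blocks and starts at the triangular number `i * (i + 1) // 2`. A quick check of that slicing with placeholder block labels:

flat = ["b00",
        "b10", "b11",
        "b20", "b21", "b22"]
num_block_rows = 3
rows = [flat[i * (i + 1) // 2:(i + 1) * (i + 2) // 2]
        for i in range(num_block_rows)]
print(rows)  # [['b00'], ['b10', 'b11'], ['b20', 'b21', 'b22']]
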
    def _solve(self, rhs, adjoint=False, adjoint_arg=False):
        # Here we follow the same use of Roth's column lemma as in `matmul`, with
        # the key difference that we replace all `matmul` instances with `solve`.
        # This follows from the property that inv(A x B) = inv(A) x inv(B).

        # Below we document the shape manipulation for adjoint=False,
        # adjoint_arg=False, but the general case of different adjoints is still
        # handled.

        if adjoint_arg:
            rhs = linalg.adjoint(rhs)

        # Always add a batch dimension to enable broadcasting to work.
        batch_shape = self._compute_ones_matrix_shape()
        rhs = rhs + array_ops.zeros(batch_shape, dtype=rhs.dtype)

        # rhs has shape [B, R, C], where B represents some number of batch
        # dimensions, R represents the number of rows, and C represents the number
        # of columns.
        # In order to apply Roth's column lemma, we need to operate on a batch of
        # column vectors, so we reshape into a batch of column vectors. We put it
        # at the front to ensure that broadcasting between operators to the batch
        # dimensions B still works.
        output = _rotate_last_dim(rhs, rotate_right=True)

        # Also expand the shape to be [A, C, B, R]. The first dimension will be
        # used to accumulate dimensions from each operator matmul.
        output = output[_ops.newaxis, ...]

        # In this loop, A is going to refer to the value of the accumulated
        # dimension. A = 1 at the start, and will end up being self.range_dimension.
        # V will refer to the last dimension. V = R at the start, and will end up
        # being 1 in the end.
        for operator in self.operators[:-1]:
            # Reshape output from [A, C, B, V] to be
            # [A, C, B, V / op.domain_dimension, op.domain_dimension]
            if adjoint:
                operator_dimension = operator.range_dimension_tensor()
            else:
                operator_dimension = operator.domain_dimension_tensor()

            output = _unvec_by(output, operator_dimension)

            # We are computing X (A^-1)^T = (A^-1 X^T)^T.
            # output has [A, C, B, V / op.domain_dimension, op.domain_dimension],
            # which is being converted to:
            # [A, C, B, V / op.domain_dimension, op.range_dimension]
            output = _linalg.matrix_transpose(output)
            output = operator.solve(output, adjoint=adjoint, adjoint_arg=False)
            output = _linalg.matrix_transpose(output)
            # Rearrange it to [A * op.range_dimension, C, B, V / op.domain_dimension]
            output = _rotate_last_dim(output, rotate_right=False)
            output = _vec(output)
            output = _rotate_last_dim(output, rotate_right=True)

        # After the loop, we will have
        # A = self.range_dimension / op[-1].range_dimension
        # V = op[-1].domain_dimension

        # We convert that using solvevec to get:
        # [A, C, B, op[-1].range_dimension]
        output = self.operators[-1].solvevec(output, adjoint=adjoint)
        # Rearrange shape to be [B1, ... Bn, self.range_dimension, C]
        output = _rotate_last_dim(output, rotate_right=False)
        output = _vec(output)
        output = _rotate_last_dim(output, rotate_right=False)

        if tensor_shape.TensorShape(rhs.shape).is_fully_defined():
            column_dim = tensor_shape.TensorShape(rhs.shape)[-1]
            broadcast_batch_shape = common_shapes.broadcast_shape(
                tensor_shape.TensorShape(rhs.shape)[:-2], self.batch_shape)
            if adjoint:
                matrix_dimensions = [self.domain_dimension, column_dim]
            else:
                matrix_dimensions = [self.range_dimension, column_dim]

            tensorshape_util.set_shape(
                output, broadcast_batch_shape.concatenate(matrix_dimensions))

        return output
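
The correctness of solving factor by factor rests on the property cited above, inv(A x B) = inv(A) x inv(B) for the Kronecker product. A NumPy sanity check of that identity (the random factors are illustrative, not library code):

import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(2, 2)) + 2 * np.eye(2)   # keep the factors well conditioned
B = rng.normal(size=(3, 3)) + 3 * np.eye(3)

lhs = np.linalg.inv(np.kron(A, B))
rhs = np.kron(np.linalg.inv(A), np.linalg.inv(B))
print(np.allclose(lhs, rhs))  # True
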
    def _matmul(self, x, adjoint=False, adjoint_arg=False):
        # Here we heavily rely on Roth's column lemma [1]:
        # (A x B) vec(X) = vec(BXA^T),
        # where vec stacks all the columns of the matrix under each other. In our
        # case, x represents a batch of vec X (i.e. we think of x as a batch of
        # column vectors, rather than a matrix). Each member of the batch can be
        # reshaped to a matrix (hence we get a batch of matrices).
        # We can iteratively apply this lemma by noting that if B is a Kronecker
        # product, then we can apply the lemma again.

        # [1] W. E. Roth, "On direct product matrices,"
        # Bulletin of the American Mathematical Society, vol. 40, pp. 461-468,
        # 1934

        # Efficiency

        # Naively doing the Kronecker product, by calculating the dense matrix and
        # applying it, can take cubic time in the size of domain_dimension
        # (assuming a square matrix). The other issue is that calculating the dense
        # matrix can be prohibitively expensive, in that it can take a large amount
        # of memory.
        #
        # This implementation avoids this memory blow up by only computing matmuls
        # with the factors. In this way, we don't have to realize the dense matrix.
        # In terms of complexity, if we have Kronecker Factors of size:
        # (n1, n1), (n2, n2), (n3, n3), ... (nJ, nJ), with N = \prod n_i, and we
        # have as input a [N, M] matrix, the naive approach would take O(N^2 M).
        # With this approach (ignoring reshaping of tensors and transposes for now),
        # the time complexity can be O(M * (\sum n_i) * N). There is also the
        # benefit of batched multiplication (in this example, the batch size is
        # roughly M * N), so this can be much faster. However, not factored in are
        # the costs of the several tensor transposes, which can affect cache
        # behavior.

        # Below we document the shape manipulation for adjoint=False,
        # adjoint_arg=False, but the general case of different adjoints is still
        # handled.

        if adjoint_arg:
            x = linalg.adjoint(x)

        # Always add a batch dimension to enable broadcasting to work.
        batch_shape = self._compute_ones_matrix_shape()
        x = x + array_ops.zeros(batch_shape, dtype=x.dtype)

        # x has shape [B, R, C], where B represents some number of batch dimensions,
        # R represents the number of rows, and C represents the number of columns.
        # In order to apply Roth's column lemma, we need to operate on a batch of
        # column vectors, so we reshape into a batch of column vectors. We put it
        # at the front to ensure that broadcasting between operators to the batch
        # dimensions B still works.
        output = _rotate_last_dim(x, rotate_right=True)

        # Also expand the shape to be [A, C, B, R]. The first dimension will be
        # used to accumulate dimensions from each operator matmul.
        output = output[_ops.newaxis, ...]

        # In this loop, A is going to refer to the value of the accumulated
        # dimension. A = 1 at the start, and will end up being self.range_dimension.
        # V will refer to the last dimension. V = R at the start, and will end up
        # being 1 in the end.
        for operator in self.operators[:-1]:
            # Reshape output from [A, C, B, V] to be
            # [A, C, B, V / op.domain_dimension, op.domain_dimension]
            if adjoint:
                operator_dimension = operator.range_dimension_tensor()
            else:
                operator_dimension = operator.domain_dimension_tensor()

            output = _unvec_by(output, operator_dimension)

            # We are computing (XA^T) = (AX^T)^T.
            # output has [A, C, B, V / op.domain_dimension, op.domain_dimension],
            # which is being converted to:
            # [A, C, B, V / op.domain_dimension, op.range_dimension]
            output = _linalg.matrix_transpose(output)
            output = operator.matmul(output,
                                     adjoint=adjoint,
                                     adjoint_arg=False)
            output = _linalg.matrix_transpose(output)
            # Rearrange it to [A * op.range_dimension, C, B, V / op.domain_dimension]
            output = _rotate_last_dim(output, rotate_right=False)
            output = _vec(output)
            output = _rotate_last_dim(output, rotate_right=True)

        # After the loop, we will have
        # A = self.range_dimension / op[-1].range_dimension
        # V = op[-1].domain_dimension

        # We convert that using matvec to get:
        # [A, C, B, op[-1].range_dimension]
        output = self.operators[-1].matvec(output, adjoint=adjoint)
        # Rearrange shape to be [B1, ... Bn, self.range_dimension, C]
        output = _rotate_last_dim(output, rotate_right=False)
        output = _vec(output)
        output = _rotate_last_dim(output, rotate_right=False)

        if tensor_shape.TensorShape(x.shape).is_fully_defined():
            column_dim = tensor_shape.TensorShape(x.shape)[-1]
            broadcast_batch_shape = common_shapes.broadcast_shape(
                tensor_shape.TensorShape(x.shape)[:-2], self.batch_shape)
            if adjoint:
                matrix_dimensions = [self.domain_dimension, column_dim]
            else:
                matrix_dimensions = [self.range_dimension, column_dim]

            tensorshape_util.set_shape(
                output, broadcast_batch_shape.concatenate(matrix_dimensions))

        return output
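
Roth's column lemma itself is easy to check numerically: with vec stacking columns (Fortran order in NumPy), (A x B) vec(X) equals vec(B X A^T). A small sketch with arbitrary shapes chosen for illustration:

import numpy as np

rng = np.random.default_rng(1)
A = rng.normal(size=(4, 3))
B = rng.normal(size=(5, 2))
X = rng.normal(size=(2, 3))    # so that B @ X @ A.T is defined

lhs = np.kron(A, B) @ X.flatten(order="F")   # (A kron B) vec(X)
rhs = (B @ X @ A.T).flatten(order="F")       # vec(B X A^T)
print(np.allclose(lhs, rhs))  # True
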
 def _log_abs_determinant(self):
   # Orthogonal matrix -> |det(Q)| = 1, so log|det(Q)| = 0.
   return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype)
 def _log_abs_determinant(self):
     return array_ops.zeros(shape=self.batch_shape_tensor(),
                            dtype=self.dtype)
Example #13
 def _determinant(self):
     if self.batch_shape.is_fully_defined():
         return array_ops.zeros(shape=self.batch_shape, dtype=self.dtype)
     else:
         return array_ops.zeros(shape=self.batch_shape_tensor(),
                                dtype=self.dtype)