Example #1
    def _matmul(self, x, adjoint=False, adjoint_arg=False):
        if self._assert_proper_shapes:
            x = linalg.adjoint(x) if adjoint_arg else x
            aps = linear_operator_util.assert_compatible_matrix_dimensions(
                self, x)
            x = control_flow_ops.with_dependencies([aps], x)
        if self.is_square:
            # Note that adjoint has no effect since this matrix is self-adjoint.
            if adjoint_arg:
                output_shape = array_ops.concat([
                    array_ops.shape(x)[:-2],
                    [array_ops.shape(x)[-1], array_ops.shape(x)[-2]]
                ], axis=0)
            else:
                output_shape = array_ops.shape(x)

            return self._possibly_broadcast_batch_shape(
                array_ops.zeros(shape=output_shape, dtype=x.dtype))

        x_shape = array_ops.shape(x)
        n = self._num_columns if adjoint else self._num_rows
        m = x_shape[-2] if adjoint_arg else x_shape[-1]

        output_shape = array_ops.concat([x_shape[:-2], [n, m]], axis=0)

        zeros = array_ops.zeros(shape=output_shape, dtype=x.dtype)
        return self._possibly_broadcast_batch_shape(zeros)
Example #2
    def _to_dense(self):
        num_cols = 0
        rows = []
        broadcasted_blocks = [
            operator.to_dense() for operator in self.operators
        ]
        broadcasted_blocks = linear_operator_util.broadcast_matrix_batch_dims(
            broadcasted_blocks)
        for block in broadcasted_blocks:
            batch_row_shape = array_ops.shape(block)[:-1]

            zeros_to_pad_before_shape = array_ops.concat(
                [batch_row_shape, [num_cols]], axis=-1)
            zeros_to_pad_before = array_ops.zeros(
                shape=zeros_to_pad_before_shape, dtype=block.dtype)
            num_cols += array_ops.shape(block)[-1]
            zeros_to_pad_after_shape = array_ops.concat(
                [batch_row_shape, [self.domain_dimension_tensor() - num_cols]],
                axis=-1)
            zeros_to_pad_after = array_ops.zeros(
                shape=zeros_to_pad_after_shape, dtype=block.dtype)

            rows.append(
                array_ops.concat(
                    [zeros_to_pad_before, block, zeros_to_pad_after], axis=-1))

        mat = array_ops.concat(rows, axis=-2)
        mat.set_shape(_ops.TensorShape(self.shape))
        return mat
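The zero-padding scheme above can be sketched with plain NumPy for the unbatched case; `blocks` below is a hypothetical list of dense blocks, not part of the library.

import numpy as np

def block_diag_dense(blocks):
    # Pad each block with zeros on the left and right so every row strip has
    # the full number of columns, then stack the strips vertically.
    total_cols = sum(b.shape[-1] for b in blocks)
    num_cols = 0
    rows = []
    for block in blocks:
        before = np.zeros((block.shape[0], num_cols), dtype=block.dtype)
        num_cols += block.shape[-1]
        after = np.zeros((block.shape[0], total_cols - num_cols), dtype=block.dtype)
        rows.append(np.concatenate([before, block, after], axis=-1))
    return np.concatenate(rows, axis=-2)

print(block_diag_dense([np.ones((2, 2)), 2.0 * np.ones((1, 1))]))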
Example #3
def _rotate_last_dim(x, rotate_right=False):
  """Rotate the last dimension either left or right."""
  ndims = array_ops.rank(x)
  if rotate_right:
    transpose_perm = array_ops.concat(
        [[ndims - 1], math_ops.range(0, ndims - 1)], axis=0)
  else:
    transpose_perm = array_ops.concat(
        [math_ops.range(1, ndims), [0]], axis=0)
  return array_ops.transpose(x, transpose_perm)
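For reference, the rotation is just an axis permutation; a standalone NumPy sketch (not part of the library):

import numpy as np

def rotate_last_dim(x, rotate_right=False):
    # rotate_right moves the last axis to the front; otherwise the first axis
    # moves to the end.
    if rotate_right:
        perm = [x.ndim - 1] + list(range(x.ndim - 1))
    else:
        perm = list(range(1, x.ndim)) + [0]
    return np.transpose(x, perm)

x = np.zeros((2, 3, 4))
print(rotate_last_dim(x, rotate_right=True).shape)   # (4, 2, 3)
print(rotate_last_dim(x, rotate_right=False).shape)  # (3, 4, 2)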
Example #4
    def _possibly_broadcast_batch_shape(self, x):
        """Return 'x', possibly after broadcasting the leading dimensions."""
        # If we have no batch shape, our batch shape broadcasts with everything!
        if self._batch_shape_arg is None:
            return x

        # Static attempt:
        #   If we determine that no broadcast is necessary, pass x through
        #   If we need a broadcast, add to an array of zeros.
        #
        # special_shape is the shape that, when broadcast with x's shape, will give
        # the correct broadcast_shape.  Note that
        #   We have already verified the second to last dimension of _ops.TensorShape(self.shape)
        #   matches x's shape in assert_compatible_matrix_dimensions.
        #   Also, the final dimension of 'x' can have any shape.
        #   Therefore, the final two dimensions of special_shape are 1's.
        special_shape = self.batch_shape.concatenate([1, 1])
        bshape = _ops.broadcast_static_shape_as_tensorshape(
            _ops.TensorShape(x.shape), special_shape)
        if special_shape.is_fully_defined():
            # bshape.is_fully_defined iff special_shape.is_fully_defined.
            if bshape == _ops.TensorShape(x.shape):
                return x
            # Use the built in broadcasting of addition.
            zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
            return x + zeros

        # Dynamic broadcast:
        #   Always add to an array of zeros, rather than using a "cond", since a
        #   cond would require copying data from GPU --> CPU.
        special_shape = array_ops.concat((self.batch_shape_tensor(), [1, 1]),
                                         0)
        zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
        return x + zeros
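The add-an-array-of-zeros broadcast used above can be reproduced with public TensorFlow ops; the shapes below are hypothetical, chosen only to show the batch dims broadcasting.

import tensorflow as tf

batch_shape = [3, 1]                     # hypothetical operator batch shape
x = tf.ones([2, 5, 4])                   # rhs whose batch shape is [2]
special_shape = tf.concat([batch_shape, [1, 1]], axis=0)  # [3, 1, 1, 1]
zeros = tf.zeros(special_shape, dtype=x.dtype)
x_bc = x + zeros                         # addition broadcasts the batch dims
print(x_bc.shape)                        # (3, 2, 5, 4)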
Example #5
def _unvec_by(y, num_col):
  """Unstack vector to form a matrix, with a specified amount of columns."""
  return _linalg.matrix_transpose(
      array_ops.reshape(
          y,
          array_ops.concat(
              [array_ops.shape(y)[:-1], [num_col, -1]], axis=0)))
Example #6
    def _vectorize_then_blockify(self, matrix):
        """Shape batch matrix to batch vector, then blockify trailing dimensions."""
        # Suppose
        #   _ops.TensorShape(matrix.shape) = [m0, m1, m2, m3],
        # and matrix is a matrix because the final two dimensions are matrix dims.
        #   self.block_depth = 2,
        #   self.block_shape = [b0, b1]  (note b0 * b1 = m2).
        # We will reshape matrix to
        #   [m3, m0, m1, b0, b1].

        # Vectorize: Reshape to batch vector.
        #   [m0, m1, m2, m3] --> [m3, m0, m1, m2]
        # This is called "vectorize" because we have taken the final two matrix dims
        # and turned this into a size m3 batch of vectors.
        vec = distribution_util.rotate_transpose(matrix, shift=1)

        # Blockify: reshape the trailing dimension into the block shape.
        #   [m3, m0, m1, m2] --> [m3, m0, m1, b0, b1]
        if (_ops.TensorShape(vec.shape).is_fully_defined()
                and self.block_shape.is_fully_defined()):
            # vec_leading_shape = [m3, m0, m1],
            # the parts of vec that will not be blockified.
            vec_leading_shape = _ops.TensorShape(vec.shape)[:-1]
            final_shape = vec_leading_shape.concatenate(self.block_shape)
        else:
            vec_leading_shape = array_ops.shape(vec)[:-1]
            final_shape = array_ops.concat(
                (vec_leading_shape, self.block_shape_tensor()), 0)
        return array_ops.reshape(vec, final_shape)
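A minimal NumPy sketch of the same shape movement, with hypothetical sizes and m2 = b0 * b1:

import numpy as np

m0, m1, b0, b1, m3 = 2, 3, 4, 5, 6
matrix = np.zeros((m0, m1, b0 * b1, m3))
# Vectorize: move the last (column) dim to the front -> [m3, m0, m1, m2].
vec = np.moveaxis(matrix, -1, 0)
# Blockify: split the trailing dim m2 into the block shape [b0, b1].
blocked = vec.reshape(vec.shape[:-1] + (b0, b1))
print(blocked.shape)  # (6, 2, 3, 4, 5)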
Example #7
    def _shape_tensor(self):
        matrix_shape = array_ops.stack((self._num_rows, self._num_columns),
                                       axis=0)
        if self._batch_shape_arg is None:
            return matrix_shape

        return array_ops.concat((self._batch_shape_arg, matrix_shape), 0)
Example #8
 def _shape_tensor(self):
     # See the shape property for an explanation of the steps.
     s_shape = array_ops.shape(self._spectrum)
     batch_shape = s_shape[:-self.block_depth]
     trailing_dims = s_shape[-self.block_depth:]
     n = math_ops.reduce_prod(trailing_dims)
     n_x_n = [n, n]
     return array_ops.concat((batch_shape, n_x_n), 0)
Example #9
 def _diag_part(self):
     diag_list = []
     for operator in self.operators:
         # Extend the axis for broadcasting.
         diag_list += [operator.diag_part()[..., array_ops.newaxis]]
     diag_list = linear_operator_util.broadcast_matrix_batch_dims(diag_list)
     diagonal = array_ops.concat(diag_list, axis=-2)
     return array_ops.squeeze(diagonal, axis=-1)
Example #10
  def __init__(self,
               col,
               row,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorToeplitz"):
    r"""Initialize a `LinearOperatorToeplitz`.

    Args:
      col: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` and `N >= 0`.
        The first column of the operator. Allowed dtypes: `float16`, `float32`,
          `float64`, `complex64`, `complex128`. Note that the first entry of
          `col` is assumed to be the same as the first entry of `row`.
      row: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` and `N >= 0`.
        The first row of the operator. Allowed dtypes: `float16`, `float32`,
          `float64`, `complex64`, `complex128`. Note that the first entry of
          `row` is assumed to be the same as the first entry of `col`.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  If `diag.dtype` is real, this is auto-set to `True`.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.
    """

    with ops.name_scope(name, values=[row, col]):
      self._row = ops.convert_to_tensor(row, name="row")
      self._col = ops.convert_to_tensor(col, name="col")
      self._check_row_col(self._row, self._col)

      circulant_col = array_ops.concat(
          [self._col,
           array_ops.zeros_like(self._col[..., 0:1]),
           array_ops.reverse(self._row[..., 1:], axis=[-1])], axis=-1)

      # To be used for matmul.
      self._circulant = linear_operator_circulant.LinearOperatorCirculant(
          fft_ops.fft(_to_complex(circulant_col)),
          input_output_dtype=self._row.dtype)

      if is_square is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError("Only square Toeplitz operators currently supported.")
      is_square = True

      super(LinearOperatorToeplitz, self).__init__(
          dtype=self._row.dtype,
          graph_parents=[self._row, self._col],
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          name=name)
Example #11
    def _broadcast_batch_dims(self, x, spectrum):
        """Broadcast batch dims of batch matrix `x` and spectrum."""
        # _ops.TensorShape(spectrum.shape) = batch_shape + block_shape
        # First make spectrum a batch matrix with
        #   _ops.TensorShape(spectrum.shape) = batch_shape + [prod(block_shape), 1]
        spec_mat = array_ops.reshape(
            spectrum,
            array_ops.concat((self.batch_shape_tensor(), [-1, 1]), axis=0))
        # Second, broadcast, possibly requiring the addition of an array of zeros.
        x, spec_mat = linear_operator_util.broadcast_matrix_batch_dims(
            (x, spec_mat))
        # Third, put the block shape back into spectrum.
        batch_shape = array_ops.shape(x)[:-2]
        spectrum = array_ops.reshape(
            spec_mat,
            array_ops.concat((batch_shape, self.block_shape_tensor()), axis=0))

        return x, spectrum
Example #12
 def reshape_inv(y):
     # Expand the extra dims hanging off the end, "b_extra_sh".
     # Note we use shape(y)[:-1] + [b_main_sh[-1]] rather than b_main_sh, because
     # y could have different batch dims than a and b, because of broadcasting.
     y_extra_shape = array_ops.concat(
         (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
     y_extra_on_end = array_ops.reshape(y, y_extra_shape)
     inverse_perm = np.argsort(perm)
     return array_ops.transpose(y_extra_on_end, perm=inverse_perm)
Example #13
    def _zeros_diag(self):
        """Returns the diagonal of this operator as all zeros."""
        if _ops.TensorShape(self.shape).is_fully_defined():
            d_shape = self.batch_shape.concatenate([self._min_matrix_dim()])
        else:
            d_shape = array_ops.concat(
                [self.batch_shape_tensor(), [self._min_matrix_dim_tensor()]],
                axis=0)

        return array_ops.zeros(shape=d_shape, dtype=self.dtype)
Example #14
    def _matmul(self, x, adjoint=False, adjoint_arg=False):
        split_dim = -1 if adjoint_arg else -2
        # Split the input along rows normally, or along columns if adjoint_arg.
        split_x = self._split_input_into_blocks(x, axis=split_dim)

        result_list = []
        for index, operator in enumerate(self.operators):
            result_list += [
                operator.matmul(split_x[index],
                                adjoint=adjoint,
                                adjoint_arg=adjoint_arg)
            ]
        result_list = linear_operator_util.broadcast_matrix_batch_dims(
            result_list)
        return array_ops.concat(result_list, axis=-2)
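As a sanity check on the split-then-concatenate strategy, a NumPy comparison against a dense block-diagonal matmul with two hypothetical blocks:

import numpy as np

rng = np.random.default_rng(0)
blocks = [rng.normal(size=(2, 2)), rng.normal(size=(3, 3))]
x = rng.normal(size=(5, 4))

dense = np.block([[blocks[0], np.zeros((2, 3))],
                  [np.zeros((3, 2)), blocks[1]]])
splits = np.split(x, [2], axis=-2)  # split the rows into sizes [2, 3]
blockwise = np.concatenate([b @ s for b, s in zip(blocks, splits)], axis=-2)
print(np.allclose(dense @ x, blockwise))  # True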
Example #15
    def _solve(self, rhs, adjoint=False, adjoint_arg=False):
        split_dim = -1 if adjoint_arg else -2
        # Split the input along rows normally, or along columns if adjoint_arg.
        split_rhs = self._split_input_into_blocks(rhs, axis=split_dim)

        solution_list = []
        for index, operator in enumerate(self.operators):
            solution_list += [
                operator.solve(split_rhs[index],
                               adjoint=adjoint,
                               adjoint_arg=adjoint_arg)
            ]

        solution_list = linear_operator_util.broadcast_matrix_batch_dims(
            solution_list)
        return array_ops.concat(solution_list, axis=-2)
Example #16
 def _diag_part(self):
   diag_part = self.operators[0].diag_part()
   for operator in self.operators[1:]:
     diag_part = diag_part[..., :, array_ops.newaxis]
     op_diag_part = operator.diag_part()[..., array_ops.newaxis, :]
     diag_part *= op_diag_part
     diag_part = array_ops.reshape(
         diag_part,
         shape=array_ops.concat(
             [array_ops.shape(diag_part)[:-2], [-1]], axis=0))
   if self.range_dimension > self.domain_dimension:
     diag_dimension = self.domain_dimension
   else:
     diag_dimension = self.range_dimension
   diag_part.set_shape(
       self.batch_shape.concatenate(diag_dimension))
   return diag_part
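The loop above uses the fact that diag(A ⊗ B) is the flattened outer product of diag(A) and diag(B); a NumPy check with hypothetical square factors:

import numpy as np

rng = np.random.default_rng(1)
A = rng.normal(size=(3, 3))
B = rng.normal(size=(4, 4))
lhs = np.diag(np.kron(A, B))
rhs = np.outer(np.diag(A), np.diag(B)).reshape(-1)
print(np.allclose(lhs, rhs))  # True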
Example #17
  def _shape_tensor(self):
    domain_dimension = self.operators[0].domain_dimension_tensor()
    for operator in self.operators[1:]:
      domain_dimension *= operator.domain_dimension_tensor()

    range_dimension = self.operators[0].range_dimension_tensor()
    for operator in self.operators[1:]:
      range_dimension *= operator.range_dimension_tensor()

    matrix_shape = [range_dimension, domain_dimension]

    # Get broadcast batch shape.
    # broadcast_shape checks for compatibility.
    batch_shape = self.operators[0].batch_shape_tensor()
    for operator in self.operators[1:]:
      batch_shape = array_ops.broadcast_dynamic_shape(
          batch_shape, operator.batch_shape_tensor())

    return array_ops.concat((batch_shape, matrix_shape), 0)
Example #18
  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    # Given a Toeplitz matrix, we can embed it in a Circulant matrix to perform
    # efficient matrix multiplications. Given a Toeplitz matrix with first row
    # [t_0, t_1, ..., t_{n-1}] and first column [t_0, t_{-1}, ..., t_{-(n-1)}],
    # let C be the circulant matrix with first column
    # [t_0, t_{-1}, ..., t_{-(n-1)}, 0, t_{n-1}, ..., t_1]. Also append `n`
    # zeros to our input vector `x`, making it a vector of length `2n` (call it
    # `y`). It can be shown
    # that if we take the first n entries of `Cy`, this is equal to the Toeplitz
    # multiplication. See:
    # http://math.mit.edu/icg/resources/teaching/18.085-spring2015/toeplitz.pdf
    # for more details.
    x = linalg.adjoint(x) if adjoint_arg else x
    expanded_x = array_ops.concat([x, array_ops.zeros_like(x)], axis=-2)
    result = self._circulant.matmul(
        expanded_x, adjoint=adjoint, adjoint_arg=False)

    return _ops.cast(
        result[..., :self.domain_dimension_tensor(), :],
        self.dtype)
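The embedding argument in the comment can be checked numerically by building both matrices explicitly; the first column/row values below are hypothetical (col[0] must equal row[0]):

import numpy as np

col = np.array([1.0, 2.0, 3.0, 4.0])  # t_0, t_{-1}, t_{-2}, t_{-3}
row = np.array([1.0, 5.0, 6.0, 7.0])  # t_0, t_1, t_2, t_3
n = len(col)

toeplitz = np.array([[col[i - j] if i >= j else row[j - i]
                      for j in range(n)] for i in range(n)])
# Circulant with first column [t_0, t_{-1}, ..., t_{-(n-1)}, 0, t_{n-1}, ..., t_1].
c = np.concatenate([col, [0.0], row[1:][::-1]])
circulant = np.array([[c[(i - j) % (2 * n)]
                       for j in range(2 * n)] for i in range(2 * n)])

x = np.arange(1.0, n + 1.0)
y = np.concatenate([x, np.zeros(n)])  # adjoin n zeros to x
print(np.allclose(toeplitz @ x, (circulant @ y)[:n]))  # True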
Example #19
 def _to_dense(self):
   product = self.operators[0].to_dense()
   for operator in self.operators[1:]:
      # Product has shape [B, R1, 1, C1, 1].
     product = product[
         ..., :, array_ops.newaxis, :, array_ops.newaxis]
     # Operator has shape [B, 1, R2, 1, C2].
     op_to_mul = operator.to_dense()[
         ..., array_ops.newaxis, :, array_ops.newaxis, :]
     # This is now [B, R1, R2, C1, C2].
     product *= op_to_mul
     # Now merge together dimensions to get [B, R1 * R2, C1 * C2].
     product = array_ops.reshape(
         product,
         shape=array_ops.concat(
             [array_ops.shape(product)[:-4],
              [array_ops.shape(product)[-4] * array_ops.shape(product)[-3],
               array_ops.shape(product)[-2] * array_ops.shape(product)[-1]]
             ], axis=0))
   product.set_shape(_ops.TensorShape(self.shape))
   return product
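The interleaved-newaxis broadcast and reshape above is the usual dense Kronecker construction; a NumPy check against np.kron with hypothetical factors:

import numpy as np

rng = np.random.default_rng(2)
A = rng.normal(size=(2, 3))
B = rng.normal(size=(4, 5))
product = A[:, None, :, None] * B[None, :, None, :]  # [R1, R2, C1, C2]
product = product.reshape(A.shape[0] * B.shape[0], A.shape[1] * B.shape[1])
print(np.allclose(product, np.kron(A, B)))  # True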
Example #20
    def _shape_tensor(self):
        # Avoid messy broadcasting if possible.
        if _ops.TensorShape(self.shape).is_fully_defined():
            return ops.convert_to_tensor(
                _ops.TensorShape(self.shape).as_list(),
                dtype=dtypes.int32,
                name="shape")

        domain_dimension = self.operators[0].domain_dimension_tensor()
        range_dimension = self.operators[0].range_dimension_tensor()
        for operator in self.operators[1:]:
            domain_dimension += operator.domain_dimension_tensor()
            range_dimension += operator.range_dimension_tensor()

        matrix_shape = array_ops.stack([domain_dimension, range_dimension])

        # Dummy Tensor of zeros.  Will never be materialized.
        zeros = array_ops.zeros(shape=self.operators[0].batch_shape_tensor())
        for operator in self.operators[1:]:
            zeros += array_ops.zeros(shape=operator.batch_shape_tensor())
        batch_shape = array_ops.shape(zeros)

        return array_ops.concat((batch_shape, matrix_shape), 0)
Example #21
    def _shape_tensor(self):
        # Avoid messy broadcasting if possible.
        if _ops.TensorShape(self.shape).is_fully_defined():
            return ops.convert_to_tensor(
                _ops.TensorShape(self.shape).as_list(),
                dtype=dtypes.int32,
                name="shape")

        # Don't check the matrix dimensions.  That would add unnecessary Asserts to
        # the graph.  Things will fail at runtime naturally if shapes are
        # incompatible.
        matrix_shape = array_ops.stack([
            self.operators[0].range_dimension_tensor(),
            self.operators[-1].domain_dimension_tensor()
        ])

        # Dummy Tensor of zeros.  Will never be materialized.
        zeros = array_ops.zeros(shape=self.operators[0].batch_shape_tensor())
        for operator in self.operators[1:]:
            zeros += array_ops.zeros(shape=operator.batch_shape_tensor())
        batch_shape = array_ops.shape(zeros)

        return array_ops.concat((batch_shape, matrix_shape), 0)
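The dummy-zeros trick in both _shape_tensor implementations above is just a way to broadcast several batch shapes; a sketch with public TensorFlow ops and hypothetical shapes:

import tensorflow as tf

batch_shapes = [tf.constant([2, 1]), tf.constant([1, 3]), tf.constant([3])]
zeros = tf.zeros(batch_shapes[0])
for s in batch_shapes[1:]:
    zeros += tf.zeros(s)  # addition broadcasts; shapes must be compatible
# Equivalent to folding tf.broadcast_dynamic_shape over the shapes.
print(tf.shape(zeros).numpy())  # [2 3]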
Example #22
    def _unblockify_then_matricize(self, vec):
        """Flatten the block dimensions then reshape to a batch matrix."""
        # Suppose
        #   _ops.TensorShape(vec.shape) = [v0, v1, v2, v3],
        #   self.block_depth = 2.
        # Then
        #   leading shape = [v0, v1]
        #   block shape = [v2, v3].
        # We will reshape vec to
        #   [v1, v2*v3, v0].

        # Un-blockify: Flatten block dimensions.  Reshape
        #   [v0, v1, v2, v3] --> [v0, v1, v2*v3].
        if _ops.TensorShape(vec.shape).is_fully_defined():
            # vec_shape = [v0, v1, v2, v3]
            vec_shape = _ops.TensorShape(vec.shape).as_list()
            # vec_leading_shape = [v0, v1]
            vec_leading_shape = vec_shape[:-self.block_depth]
            # vec_block_shape = [v2, v3]
            vec_block_shape = vec_shape[-self.block_depth:]
            # flat_shape = [v0, v1, v2*v3]
            flat_shape = vec_leading_shape + [np.prod(vec_block_shape)]
        else:
            vec_shape = array_ops.shape(vec)
            vec_leading_shape = vec_shape[:-self.block_depth]
            vec_block_shape = vec_shape[-self.block_depth:]
            flat_shape = array_ops.concat(
                (vec_leading_shape, [math_ops.reduce_prod(vec_block_shape)]),
                0)
        vec_flat = array_ops.reshape(vec, flat_shape)

        # Matricize:  Reshape to batch matrix.
        #   [v0, v1, v2*v3] --> [v1, v2*v3, v0],
        # representing a shape [v1] batch of [v2*v3, v0] matrices.
        matrix = distribution_util.rotate_transpose(vec_flat, shift=-1)
        return matrix
Example #23
 def _shape_tensor(self):
     d_shape = array_ops.shape(self._diag)
     k = d_shape[-1]
     return array_ops.concat((d_shape, [k]), 0)
Example #24
 def _shape_tensor(self):
     batch_shape = array_ops.broadcast_dynamic_shape(
         self.base_operator.batch_shape_tensor(),
         array_ops.shape(self.u)[:-2])
     return array_ops.concat(
         [batch_shape, self.base_operator.shape_tensor()[-2:]], axis=0)
Example #25
def _vec(x):
  """Stacks column of matrix to form a single column."""
  return array_ops.reshape(
      _linalg.matrix_transpose(x),
      array_ops.concat(
          [array_ops.shape(x)[:-2], [-1]], axis=0))
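_vec and _unvec_by (Example #5) are a column-stacking flatten and its inverse; a standalone NumPy round-trip sketch:

import numpy as np

def vec(x):
    # Stack the columns of x into one long vector (column-major flatten).
    return np.swapaxes(x, -1, -2).reshape(x.shape[:-2] + (-1,))

def unvec_by(y, num_col):
    # Inverse of vec when the number of columns is known.
    return np.swapaxes(y.reshape(y.shape[:-1] + (num_col, -1)), -1, -2)

x = np.arange(12.0).reshape(3, 4)
print(np.allclose(unvec_by(vec(x), num_col=4), x))  # True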
Example #26
 def _shape_tensor(self):
   v_shape = array_ops.broadcast_dynamic_shape(
       array_ops.shape(self.row),
       array_ops.shape(self.col))
   k = v_shape[-1]
   return array_ops.concat((v_shape, [k]), 0)
Example #27
  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    # Here we follow the same use of Roth's column lemma as in `matmul`, with
    # the key difference that we replace all `matmul` instances with `solve`.
    # This follows from the property that inv(A x B) = inv(A) x inv(B).

    # Below we document the shape manipulation for adjoint=False,
    # adjoint_arg=False, but the general case of different adjoints is still
    # handled.

    if adjoint_arg:
      rhs = linalg.adjoint(rhs)

    # Always add a batch dimension to enable broadcasting to work.
    batch_shape = array_ops.concat(
        [array_ops.ones_like(self.batch_shape_tensor()), [1, 1]], 0)
    rhs += array_ops.zeros(batch_shape, dtype=rhs.dtype)

    # rhs has shape [B, R, C], where B represents some number of batch
    # dimensions,
    # R represents the number of rows, and C represents the number of columns.
    # In order to apply Roth's column lemma, we need to operate on a batch of
    # column vectors, so we reshape into a batch of column vectors. We put it
    # at the front to ensure that broadcasting between operators to the batch
    # dimensions B still works.
    output = _rotate_last_dim(rhs, rotate_right=True)

    # Also expand the shape to be [A, C, B, R]. The first dimension will be
    # used to accumulate dimensions from each operator matmul.
    output = output[array_ops.newaxis, ...]

    # In this loop, A is going to refer to the value of the accumulated
    # dimension. A = 1 at the start, and will end up being self.range_dimension.
    # V will refer to the last dimension. V = R at the start, and will end up
    # being 1 in the end.
    for operator in self.operators[:-1]:
      # Reshape output from [A, C, B, V] to be
      # [A, C, B, V / op.domain_dimension, op.domain_dimension]
      if adjoint:
        operator_dimension = operator.range_dimension_tensor()
      else:
        operator_dimension = operator.domain_dimension_tensor()

      output = _unvec_by(output, operator_dimension)

      # We are computing X A^{-T} = (A^{-1} X^T)^T.
      # output has [A, C, B, V / op.domain_dimension, op.domain_dimension],
      # which is being converted to:
      # [A, C, B, V / op.domain_dimension, op.range_dimension]
      output = _linalg.matrix_transpose(output)
      output = operator.solve(output, adjoint=adjoint, adjoint_arg=False)
      output = _linalg.matrix_transpose(output)
      # Rearrange it to [A * op.range_dimension, C, B, V / op.domain_dimension]
      output = _rotate_last_dim(output, rotate_right=False)
      output = _vec(output)
      output = _rotate_last_dim(output, rotate_right=True)

    # After the loop, we will have
    # A = self.range_dimension / op[-1].range_dimension
    # V = op[-1].domain_dimension

    # We convert that using matvec to get:
    # [A, C, B, op[-1].range_dimension]
    output = self.operators[-1].solvevec(output, adjoint=adjoint)
    # Rearrange shape to be [B1, ... Bn, self.range_dimension, C]
    output = _rotate_last_dim(output, rotate_right=False)
    output = _vec(output)
    output = _rotate_last_dim(output, rotate_right=False)

    if _ops.TensorShape(rhs.shape).is_fully_defined():
      column_dim = _ops.TensorShape(rhs.shape)[-1]
      broadcast_batch_shape = common_shapes.broadcast_shape(
          _ops.TensorShape(rhs.shape)[:-2], self.batch_shape)
      if adjoint:
        matrix_dimensions = [self.domain_dimension, column_dim]
      else:
        matrix_dimensions = [self.range_dimension, column_dim]

      output.set_shape(broadcast_batch_shape.concatenate(
          matrix_dimensions))

    return output
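A NumPy sanity check of the property quoted above, inv(A ⊗ B) = inv(A) ⊗ inv(B), with hypothetical well-conditioned factors:

import numpy as np

rng = np.random.default_rng(3)
A = rng.normal(size=(2, 2)) + 2.0 * np.eye(2)
B = rng.normal(size=(3, 3)) + 2.0 * np.eye(3)
rhs = rng.normal(size=(6, 4))

direct = np.linalg.solve(np.kron(A, B), rhs)
factored = np.kron(np.linalg.inv(A), np.linalg.inv(B)) @ rhs
print(np.allclose(direct, factored))  # True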
Example #28
def broadcast_matrix_batch_dims(batch_matrices, name=None):
    """Broadcast leading dimensions of zero or more [batch] matrices.

  Example broadcasting one batch dim of two simple matrices.

  ```python
  x = [[1, 2],
       [3, 4]]  # Shape [2, 2], no batch dims

  y = [[[1]]]   # Shape [1, 1, 1], 1 batch dim of shape [1]

  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc
  ==> [[[1, 2],
        [3, 4]]]  # Shape [1, 2, 2], 1 batch dim of shape [1].

  y_bc
  ==> same as y
  ```

  Example broadcasting many batch dims

  ```python
  x = tf.random.normal(shape=(2, 3, 1, 4, 4))
  y = tf.random.normal(shape=(1, 3, 2, 5, 5))
  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  _ops.TensorShape(x_bc.shape)
  ==> (2, 3, 2, 4, 4)

  _ops.TensorShape(y_bc.shape)
  ==> (2, 3, 2, 5, 5)
  ```

  Args:
    batch_matrices:  Iterable of `Tensor`s, each having two or more dimensions.
    name:  A string name to prepend to created ops.

  Returns:
    bcast_matrices: List of `Tensor`s, with `bcast_matrices[i]` containing
      the values from `batch_matrices[i]`, with possibly broadcast batch dims.

  Raises:
    ValueError:  If any input `Tensor` is statically determined to have fewer
      than two dimensions.
  """
    with ops.name_scope(name or "broadcast_matrix_batch_dims",
                        values=batch_matrices):
        check_ops.assert_proper_iterable(batch_matrices)
        batch_matrices = list(batch_matrices)

        for i, mat in enumerate(batch_matrices):
            batch_matrices[i] = ops.convert_to_tensor(mat)
            assert_is_batch_matrix(batch_matrices[i])

        if len(batch_matrices) < 2:
            return batch_matrices

        # Try static broadcasting.
        # bcast_batch_shape is the broadcast batch shape of ALL matrices.
        # E.g. if batch_matrices = [x, y], with
        # _ops.TensorShape(x.shape) =    [2, j, k]  (batch shape =    [2])
        # _ops.TensorShape(y.shape) = [3, 1, l, m]  (batch shape = [3, 1])
        # ==> bcast_batch_shape = [3, 2]
        bcast_batch_shape = _ops.TensorShape(batch_matrices[0].shape)[:-2]
        for mat in batch_matrices[1:]:
            bcast_batch_shape = _ops.broadcast_static_shape_as_tensorshape(
                bcast_batch_shape,
                _ops.TensorShape(mat.shape)[:-2])
        if bcast_batch_shape.is_fully_defined():
            for i, mat in enumerate(batch_matrices):
                if _ops.TensorShape(mat.shape)[:-2] != bcast_batch_shape:
                    bcast_shape = array_ops.concat([
                        bcast_batch_shape.as_list(),
                        array_ops.shape(mat)[-2:]
                    ], axis=0)
                    batch_matrices[i] = _ops.broadcast_to(mat, bcast_shape)
            return batch_matrices

        # Since static didn't work, do dynamic, which always copies data.
        bcast_batch_shape = array_ops.shape(batch_matrices[0])[:-2]
        for mat in batch_matrices[1:]:
            bcast_batch_shape = array_ops.broadcast_dynamic_shape(
                bcast_batch_shape,
                array_ops.shape(mat)[:-2])
        for i, mat in enumerate(batch_matrices):
            batch_matrices[i] = _ops.broadcast_to(
                mat,
                array_ops.concat(
                    [bcast_batch_shape,
                     array_ops.shape(mat)[-2:]], axis=0))

        return batch_matrices
Example #29
def _reshape_for_efficiency(a,
                            b,
                            transpose_a=False,
                            transpose_b=False,
                            adjoint_a=False,
                            adjoint_b=False):
    """Maybe reshape a, b, and return an inverse map.  For matmul/solve."""
    def identity(x):
        return x

    # At this point, we have not taken transpose/adjoint of a/b.
    still_need_to_transpose = True

    if _ops.TensorShape(a.shape).ndims is None or _ops.TensorShape(
            b.shape).ndims is None:
        return a, b, identity, still_need_to_transpose

    # This could be handled in the future, but seems less common.
    if _ops.TensorShape(a.shape).ndims >= _ops.TensorShape(b.shape).ndims:
        return a, b, identity, still_need_to_transpose

    # From now on, we might modify b, but will not modify a.

    # Suppose:
    #   _ops.TensorShape(a.shape) =     C + [m, n]
    #   _ops.TensorShape(b.shape) = S + C + [n, r]
    b_extra_ndims = _ops.TensorShape(b.shape).ndims - _ops.TensorShape(
        a.shape).ndims

    # b_extra_sh = S, b_main_sh = C + [n, r]
    b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
    b_main_sh = array_ops.shape(b)[b_extra_ndims:]

    # No reason to flip unless the extra dims of b are big enough.  Why?
    # Assume adjoint/transpose = False.  Then...
    # By not flipping, we have to replicate a to shape
    #   b_extra_sh + _ops.TensorShape(a.shape),
    # which could use extra memory.  But in all cases, the final output has shape
    #   b_extra_sh + _ops.TensorShape(a.shape)[:-1] + [_ops.TensorShape(b.shape)[-1]]
    # So we only end up creating a larger object if the end dim of b is smaller
    # than the end dim of a.  This often happens, e.g. if b was a vector that was
    # expanded to a matrix (by appending a singleton).

    # Since adjoint/transpose may not be False, we must make adjustments here.
    # The dim of b that holds the multiple equations.
    a_domain_sz_ = _ops.TensorShape(
        a.shape)[-2 if adjoint_a or transpose_a else -1]
    b_eq_sz_ = _ops.TensorShape(
        b.shape)[-2 if adjoint_b or transpose_b else -1]
    b_extra_sz_ = (
        np.prod(_ops.TensorShape(b.shape)[:b_extra_ndims].as_list()) if
        _ops.TensorShape(b.shape)[:b_extra_ndims].is_fully_defined() else None)
    if (a_domain_sz_ is not None and b_eq_sz_ is not None
            and b_extra_sz_ is not None):
        if b_extra_sz_ < 2 or a_domain_sz_ <= b_eq_sz_:
            return a, b, identity, still_need_to_transpose

    # At this point, we're flipping for sure!
    # Any transposes/adjoints will happen here explicitly, rather than in calling
    # code.  Why?  To avoid having to write separate complex code for each case.
    if adjoint_a:
        a = linalg.adjoint(a)
    elif transpose_a:
        a = linalg.transpose(a)
    if adjoint_b:
        b = linalg.adjoint(b)
    elif transpose_b:
        b = linalg.transpose(b)
    still_need_to_transpose = False

    # Recompute shapes, since the transpose/adjoint may have changed them.
    b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
    b_main_sh = array_ops.shape(b)[b_extra_ndims:]

    # Permutation to put the extra dims at the end.
    perm = (np.concatenate(
        (np.arange(b_extra_ndims,
                   _ops.TensorShape(b.shape).ndims), np.arange(
                       0, b_extra_ndims)), 0))
    b_extra_on_end = array_ops.transpose(b, perm=perm)

    # Now squash this end into one long dim.
    b_squashed_end = array_ops.reshape(
        b_extra_on_end, array_ops.concat((b_main_sh[:-1], [-1]), 0))

    def reshape_inv(y):
        # Expand the extra dims hanging off the end, "b_extra_sh".
        # Note we use shape(y)[:-1] + [b_main_sh[-1]] rather than b_main_sh, because
        # y could have different batch dims than a and b, because of broadcasting.
        y_extra_shape = array_ops.concat(
            (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
        y_extra_on_end = array_ops.reshape(y, y_extra_shape)
        inverse_perm = np.argsort(perm)
        return array_ops.transpose(y_extra_on_end, perm=inverse_perm)

    return a, b_squashed_end, reshape_inv, still_need_to_transpose
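A standalone NumPy sketch of the flip performed above, with hypothetical shapes: `a` has batch shape C = [2] and matrix shape [3, 4], while `b` has extra leading shape S = [5] and matrix shape [4, 1].

import numpy as np

rng = np.random.default_rng(4)
a = rng.normal(size=(2, 3, 4))     # C + [m, n]
b = rng.normal(size=(5, 2, 4, 1))  # S + C + [n, r]

naive = np.matmul(a, b)            # broadcasts a across S; shape [5, 2, 3, 1]

perm = [1, 2, 3, 0]                                    # move S to the end
b_squashed = np.transpose(b, perm).reshape(2, 4, -1)   # C + [n, r * 5]
y = np.matmul(a, b_squashed)                           # C + [m, r * 5]
y = y.reshape(2, 3, 1, 5)                              # C + [m, r] + S
recovered = np.transpose(y, np.argsort(perm))          # S + C + [m, r]
print(np.allclose(naive, recovered))  # True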
Example #30
  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    # Here we heavily rely on Roth's column Lemma [1]:
    # (A x B) * vec X = vec BXA^T,
    # where vec stacks all the columns of the matrix under each other. In our
    # case, x represents a batch of vec X (i.e. we think of x as a batch of
    # column vectors, rather than a matrix). Each member of the batch can be
    # reshaped to a matrix (hence we get a batch of matrices).
    # We can iteratively apply this lemma by noting that if B is a Kronecker
    # product, then we can apply the lemma again.

    # [1] W. E. Roth, "On direct product matrices,"
    # Bulletin of the American Mathematical Society, vol. 40, pp. 461-468,
    # 1934

    # Efficiency

    # Naively doing the Kronecker product, by calculating the dense matrix and
    # applying it, can take cubic time in the size of domain_dimension
    # (assuming a square matrix). The other issue is that calculating the dense
    # matrix can be prohibitively expensive, in that it can take a large amount
    # of memory.
    #
    # This implementation avoids this memory blow up by only computing matmuls
    # with the factors. In this way, we don't have to realize the dense matrix.
    # In terms of complexity, if we have Kronecker Factors of size:
    # (n1, n1), (n2, n2), (n3, n3), ... (nJ, nJ), with N = \prod n_i, and we
    # have as input a [N, M] matrix, the naive approach would take O(N^2 M).
    # With this approach (ignoring reshaping of tensors and transposes for now),
    # the time complexity can be O(M * (\sum n_i) * N). There is also the
    # benefit of batched multiplication (in this example, the batch size is
    # roughly M * N), so this can be much faster. However, this does not factor
    # in the cost of the several tensor transposes, which can affect cache
    # behavior.

    # Below we document the shape manipulation for adjoint=False,
    # adjoint_arg=False, but the general case of different adjoints is still
    # handled.

    if adjoint_arg:
      x = linalg.adjoint(x)

    # Always add a batch dimension to enable broadcasting to work.
    batch_shape = array_ops.concat(
        [array_ops.ones_like(self.batch_shape_tensor()), [1, 1]], 0)
    x += array_ops.zeros(batch_shape, dtype=x.dtype)

    # x has shape [B, R, C], where B represents some number of batch dimensions,
    # R represents the number of rows, and C represents the number of columns.
    # In order to apply Roth's column lemma, we need to operate on a batch of
    # column vectors, so we reshape into a batch of column vectors. We put it
    # at the front to ensure that broadcasting between operators to the batch
    # dimensions B still works.
    output = _rotate_last_dim(x, rotate_right=True)

    # Also expand the shape to be [A, C, B, R]. The first dimension will be
    # used to accumulate dimensions from each operator matmul.
    output = output[array_ops.newaxis, ...]

    # In this loop, A is going to refer to the value of the accumulated
    # dimension. A = 1 at the start, and will end up being self.range_dimension.
    # V will refer to the last dimension. V = R at the start, and will end up
    # being 1 in the end.
    for operator in self.operators[:-1]:
      # Reshape output from [A, C, B, V] to be
      # [A, C, B, V / op.domain_dimension, op.domain_dimension]
      if adjoint:
        operator_dimension = operator.range_dimension_tensor()
      else:
        operator_dimension = operator.domain_dimension_tensor()

      output = _unvec_by(output, operator_dimension)

      # We are computing (XA^T) = (AX^T)^T.
      # output has [A, C, B, V / op.domain_dimension, op.domain_dimension],
      # which is being converted to:
      # [A, C, B, V / op.domain_dimension, op.range_dimension]
      output = _linalg.matrix_transpose(output)
      output = operator.matmul(output, adjoint=adjoint, adjoint_arg=False)
      output = _linalg.matrix_transpose(output)
      # Rearrange it to [A * op.range_dimension, C, B, V / op.domain_dimension]
      output = _rotate_last_dim(output, rotate_right=False)
      output = _vec(output)
      output = _rotate_last_dim(output, rotate_right=True)

    # After the loop, we will have
    # A = self.range_dimension / op[-1].range_dimension
    # V = op[-1].domain_dimension

    # We convert that using matvec to get:
    # [A, C, B, op[-1].range_dimension]
    output = self.operators[-1].matvec(output, adjoint=adjoint)
    # Rearrange shape to be [B1, ... Bn, self.range_dimension, C]
    output = _rotate_last_dim(output, rotate_right=False)
    output = _vec(output)
    output = _rotate_last_dim(output, rotate_right=False)

    if _ops.TensorShape(x.shape).is_fully_defined():
      column_dim = _ops.TensorShape(x.shape)[-1]
      broadcast_batch_shape = common_shapes.broadcast_shape(
          _ops.TensorShape(x.shape)[:-2], self.batch_shape)
      if adjoint:
        matrix_dimensions = [self.domain_dimension, column_dim]
      else:
        matrix_dimensions = [self.range_dimension, column_dim]

      output.set_shape(broadcast_batch_shape.concatenate(
          matrix_dimensions))

    return output
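Roth's column lemma used throughout this method can be checked directly in NumPy, with vec taken as column-stacking and hypothetical factor shapes:

import numpy as np

rng = np.random.default_rng(5)
A = rng.normal(size=(3, 4))
B = rng.normal(size=(5, 2))
X = rng.normal(size=(2, 4))

def vec(m):
    # Stack the columns of m under each other.
    return m.reshape(-1, order="F")

lhs = np.kron(A, B) @ vec(X)
rhs = vec(B @ X @ A.T)
print(np.allclose(lhs, rhs))  # True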