def _unvec_by(y, num_col):
    """Unstack vector to form a matrix, with a specified amount of columns."""
    return _linalg.matrix_transpose(
        array_ops.reshape(
            y,
            array_ops.concat([array_ops.shape(y)[:-1], [num_col, -1]],
                             axis=0)))
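As a quick sanity check (a minimal NumPy sketch, not the library's API), the reshape-then-transpose above unstacks a flat vector into `num_col` columns in column-major order, i.e. it matches a Fortran-order reshape:

import numpy as np

y = np.arange(6.0)
num_col = 3
# Same recipe as _unvec_by, specialized to a single (non-batched) vector.
unvec = y.reshape(num_col, -1).T
np.testing.assert_array_equal(unvec, y.reshape(-1, num_col, order="F"))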
 def _to_dense(self):
     product = self.operators[0].to_dense()
     for operator in self.operators[1:]:
         # Product has shape [B, R1, 1, C1, 1].
         product = product[..., :, _ops.newaxis, :, _ops.newaxis]
         # Operator has shape [B, 1, R2, 1, C2].
         op_to_mul = operator.to_dense()[..., _ops.newaxis, :,
                                         _ops.newaxis, :]
         # This is now [B, R1, R2, C1, C2].
         product = product * op_to_mul
         # Now merge together dimensions to get [B, R1 * R2, C1 * C2].
         product = array_ops.reshape(product,
                                     shape=array_ops.concat([
                                         array_ops.shape(product)[:-4],
                                         [
                                             array_ops.shape(product)[-4] *
                                             array_ops.shape(product)[-3],
                                             array_ops.shape(product)[-2] *
                                             array_ops.shape(product)[-1]
                                         ]
                                     ],
                                                            axis=0))
     tensorshape_util.set_shape(product,
                                tensor_shape.TensorShape(self.shape))
     return product
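The broadcasting-and-reshape loop above builds the dense Kronecker product of the factors. A minimal NumPy sketch (assumed 2-D matrices, not the library's API) of one step of that loop:

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3))   # [R1, C1]
B = rng.standard_normal((4, 5))   # [R2, C2]

# Insert singleton axes, multiply, then merge dims to [R1 * R2, C1 * C2].
product = A[:, np.newaxis, :, np.newaxis] * B[np.newaxis, :, np.newaxis, :]
product = product.reshape(A.shape[0] * B.shape[0], A.shape[1] * B.shape[1])

np.testing.assert_allclose(product, np.kron(A, B))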
Example #3
    def _vectorize_then_blockify(self, matrix):
        """Shape batch matrix to batch vector, then blockify trailing dimensions."""
        # Suppose
        #   tensor_shape.TensorShape(matrix.shape) = [m0, m1, m2, m3],
        # and matrix is a matrix because the final two dimensions are matrix dims.
        #   self.block_depth = 2,
        #   self.block_shape = [b0, b1]  (note b0 * b1 = m2).
        # We will reshape matrix to
        #   [m3, m0, m1, b0, b1].

        # Vectorize: Reshape to batch vector.
        #   [m0, m1, m2, m3] --> [m3, m0, m1, m2]
        # This is called "vectorize" because we have taken the final two matrix dims
        # and turned this into a size m3 batch of vectors.
        vec = distribution_util.rotate_transpose(matrix, shift=1)

        # Blockify: Blockify trailing dimensions.
        #   [m3, m0, m1, m2] --> [m3, m0, m1, b0, b1]
        if (tensor_shape.TensorShape(vec.shape).is_fully_defined()
                and self.block_shape.is_fully_defined()):
            # vec_leading_shape = [m3, m0, m1],
            # the parts of vec that will not be blockified.
            vec_leading_shape = tensor_shape.TensorShape(vec.shape)[:-1]
            final_shape = vec_leading_shape.concatenate(self.block_shape)
        else:
            vec_leading_shape = prefer_static.shape(vec)[:-1]
            final_shape = prefer_static.concat(
                (vec_leading_shape, self.block_shape_tensor()), 0)
        return array_ops.reshape(vec, final_shape)
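A minimal NumPy sketch (with assumed example shapes) of the same two steps, where np.moveaxis plays the role of rotate_transpose(..., shift=1):

import numpy as np

m0, m1, b0, b1, m3 = 2, 3, 4, 5, 6
matrix = np.zeros((m0, m1, b0 * b1, m3))   # m2 = b0 * b1

# Vectorize: [m0, m1, m2, m3] --> [m3, m0, m1, m2]
vec = np.moveaxis(matrix, -1, 0)
# Blockify: [m3, m0, m1, m2] --> [m3, m0, m1, b0, b1]
blocked = vec.reshape(vec.shape[:-1] + (b0, b1))
assert blocked.shape == (m3, m0, m1, b0, b1)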
Example #4
  def _unblockify_then_matricize(self, vec):
    """Flatten the block dimensions then reshape to a batch matrix."""
    # Suppose
    #   tensor_shape.TensorShape(vec.shape) = [v0, v1, v2, v3],
    #   self.block_depth = 2.
    # Then
    #   leading shape = [v0, v1]
    #   block shape = [v2, v3].
    # We will reshape vec to
    #   [v1, v2*v3, v0].

    # Un-blockify: Flatten block dimensions.  Reshape
    #   [v0, v1, v2, v3] --> [v0, v1, v2*v3].
    if tensor_shape.TensorShape(vec.shape).is_fully_defined():
      # vec_shape = [v0, v1, v2, v3]
      vec_shape = tensor_shape.TensorShape(vec.shape).as_list()
      # vec_leading_shape = [v0, v1]
      vec_leading_shape = vec_shape[:-self.block_depth]
      # vec_block_shape = [v2, v3]
      vec_block_shape = vec_shape[-self.block_depth:]
      # flat_shape = [v0, v1, v2*v3]
      flat_shape = vec_leading_shape + [np.prod(vec_block_shape)]
    else:
      vec_shape = array_ops.shape(vec)
      vec_leading_shape = vec_shape[:-self.block_depth]
      vec_block_shape = vec_shape[-self.block_depth:]
      flat_shape = array_ops.concat(
          (vec_leading_shape, [math_ops.reduce_prod(vec_block_shape)]), 0)
    vec_flat = array_ops.reshape(vec, flat_shape)

    # Matricize:  Reshape to batch matrix.
    #   [v0, v1, v2*v3] --> [v1, v2*v3, v0],
    # representing a shape [v1] batch of [v2*v3, v0] matrices.
    matrix = distribution_util.rotate_transpose(vec_flat, shift=-1)
    return matrix
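The inverse direction, again as a minimal NumPy sketch with assumed shapes (np.moveaxis plays the role of rotate_transpose(..., shift=-1)):

import numpy as np

v0, v1, v2, v3 = 2, 3, 4, 5
vec = np.zeros((v0, v1, v2, v3))

# Un-blockify: [v0, v1, v2, v3] --> [v0, v1, v2*v3]
vec_flat = vec.reshape(v0, v1, v2 * v3)
# Matricize: [v0, v1, v2*v3] --> [v1, v2*v3, v0]
matrix = np.moveaxis(vec_flat, 0, -1)
assert matrix.shape == (v1, v2 * v3, v0)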
Example #5
 def reshape_inv(y):
   # Expand the extra dims hanging off the end, "b_extra_sh".
   # Note we use y_sh[:-1] + [b_main_sh[-1]] rather than b_main_sh, because y
   # could have different batch dims than a and b, due to broadcasting.
   y_extra_shape = array_ops.concat(
       (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
   y_extra_on_end = array_ops.reshape(y, y_extra_shape)
   inverse_perm = np.argsort(perm)
   return array_ops.transpose(y_extra_on_end, perm=inverse_perm)
Example #6
  def _broadcast_batch_dims(self, x, spectrum):
    """Broadcast batch dims of batch matrix `x` and spectrum."""
    # _ops.TensorShape(spectrum.shape) = batch_shape + block_shape
    # First make spectrum a batch matrix with
    #   _ops.TensorShape(spectrum.shape) = batch_shape + [prod(block_shape), 1]
    spec_mat = array_ops.reshape(
        spectrum, array_ops.concat(
            (self.batch_shape_tensor(), [-1, 1]), axis=0))
    # Second, broadcast, possibly requiring the addition of an array of zeros.
    x, spec_mat = linear_operator_util.broadcast_matrix_batch_dims((x,
                                                                    spec_mat))
    # Third, put the block shape back into spectrum.
    batch_shape = array_ops.shape(x)[:-2]
    spectrum = array_ops.reshape(
        spec_mat,
        array_ops.concat((batch_shape, self.block_shape_tensor()), axis=0))

    return x, spectrum
Example #7
    def _broadcast_batch_dims(self, x, spectrum):
        """Broadcast batch dims of batch matrix `x` and spectrum."""
        spectrum = ops.convert_to_tensor(spectrum, name="spectrum")
        # tensor_shape.TensorShape(spectrum.shape) = batch_shape + block_shape
        # First make spectrum a batch matrix with
        #   tensor_shape.TensorShape(spectrum.shape) = batch_shape + [prod(block_shape), 1]
        batch_shape = self._batch_shape_tensor(shape=self._shape_tensor(
            spectrum=spectrum))
        spec_mat = array_ops.reshape(
            spectrum, prefer_static.concat((batch_shape, [-1, 1]), axis=0))
        # Second, broadcast, possibly requiring the addition of an array of zeros.
        x, spec_mat = linear_operator_util.broadcast_matrix_batch_dims(
            (x, spec_mat))
        # Third, put the block shape back into spectrum.
        x_batch_shape = prefer_static.shape(x)[:-2]
        spectrum_shape = prefer_static.shape(spectrum)
        spectrum = array_ops.reshape(
            spec_mat,
            prefer_static.concat(
                (x_batch_shape,
                 self._block_shape_tensor(spectrum_shape=spectrum_shape)),
                axis=0))

        return x, spectrum
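Both versions follow the same three steps. A minimal NumPy sketch (assumed shapes, plain np.broadcast_to instead of the library's broadcast_matrix_batch_dims helper):

import numpy as np

block_shape = (3, 4)
spectrum = np.ones((2,) + block_shape)   # batch_shape = [2]
x = np.ones((5, 2, 12, 1))               # batch matrix with batch dims [5, 2]

# First: flatten spectrum to a batch matrix [..., prod(block_shape), 1].
spec_mat = spectrum.reshape(spectrum.shape[:-len(block_shape)] + (-1, 1))
# Second: broadcast the batch dims of x and spec_mat against each other.
batch = np.broadcast_shapes(x.shape[:-2], spec_mat.shape[:-2])
x = np.broadcast_to(x, batch + x.shape[-2:])
spec_mat = np.broadcast_to(spec_mat, batch + spec_mat.shape[-2:])
# Third: put the block shape back into spectrum.
spectrum = spec_mat.reshape(batch + block_shape)
assert spectrum.shape == (5, 2) + block_shape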
 def _diag_part(self):
     diag_part = self.operators[0].diag_part()
     for operator in self.operators[1:]:
         diag_part = diag_part[..., :, array_ops.newaxis]
         op_diag_part = operator.diag_part()[..., array_ops.newaxis, :]
         diag_part *= op_diag_part
         diag_part = array_ops.reshape(
             diag_part,
             shape=array_ops.concat([array_ops.shape(diag_part)[:-2], [-1]],
                                    axis=0))
     if self.range_dimension > self.domain_dimension:
         diag_dimension = self.domain_dimension
     else:
         diag_dimension = self.range_dimension
     diag_part.set_shape(self.batch_shape.concatenate(diag_dimension))
     return diag_part
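The loop relies on the fact that the diagonal of a Kronecker product is the Kronecker product of the diagonals. A minimal NumPy check for two square factors:

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3))
B = rng.standard_normal((4, 4))

np.testing.assert_allclose(np.diag(np.kron(A, B)),
                           np.kron(np.diag(A), np.diag(B)))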
Example #9
 def _eigvals(self):
   # This will be the kronecker product of all the eigenvalues.
   # Note: It doesn't matter which kronecker product it is, since all
   # kronecker products of the same matrices are similar.
   eigvals = [operator.eigvals() for operator in self.operators]
   # Now compute the kronecker product
   product = eigvals[0]
   for eigval in eigvals[1:]:
     # Product has shape [B, R1, 1].
     product = product[..., _ops.newaxis]
     # Eigval has shape [B, 1, R2]. Produces shape [B, R1, R2].
     product = product * eigval[..., _ops.newaxis, :]
     # Reshape to [B, R1 * R2]
     product = array_ops.reshape(
         product,
         shape=prefer_static.concat([prefer_static.shape(product)[:-2], [-1]], axis=0))
   tensorshape_util.set_shape(product, tensor_shape.TensorShape(self.shape)[:-1])
   return product
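A minimal NumPy check of the property the comment refers to: the eigenvalues of a Kronecker product are the pairwise products of the factors' eigenvalues, up to ordering (symmetric factors are used here so the spectra are real and easy to sort):

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3))
A = A + A.T
B = rng.standard_normal((4, 4))
B = B + B.T

ev_kron = np.sort(np.linalg.eigvalsh(np.kron(A, B)))
ev_pair = np.sort(np.kron(np.linalg.eigvalsh(A), np.linalg.eigvalsh(B)))
np.testing.assert_allclose(ev_kron, ev_pair, atol=1e-8)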
def _vec(x):
    """Stacks column of matrix to form a single column."""
    return array_ops.reshape(
        _linalg.matrix_transpose(x),
        array_ops.concat([array_ops.shape(x)[:-2], [-1]], axis=0))
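_vec and _unvec_by from earlier in this listing are inverses of one another. A minimal NumPy round trip (not the library's API) on a single matrix:

import numpy as np

x = np.arange(6.0).reshape(2, 3)

vec_x = x.T.reshape(-1)                    # _vec: stack the columns of x
x_back = vec_x.reshape(x.shape[1], -1).T   # _unvec_by(vec_x, num_col=3)

np.testing.assert_array_equal(x_back, x)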
Example #11
def _reshape_for_efficiency(a,
                            b,
                            transpose_a=False,
                            transpose_b=False,
                            adjoint_a=False,
                            adjoint_b=False):
  """Maybe reshape a, b, and return an inverse map.  For matmul/solve."""
  def identity(x):
    return x

  # At this point, we have not taken transpose/adjoint of a/b.
  still_need_to_transpose = True

  if tensor_shape.TensorShape(a.shape).ndims is None or tensor_shape.TensorShape(b.shape).ndims is None:
    return a, b, identity, still_need_to_transpose

  # This could be handled in the future, but seems less common.
  if tensor_shape.TensorShape(a.shape).ndims >= tensor_shape.TensorShape(b.shape).ndims:
    return a, b, identity, still_need_to_transpose

  # From now on, we might modify b, but will not modify a.

  # Suppose:
  #   tensor_shape.TensorShape(a.shape) =     C + [m, n],
  #   tensor_shape.TensorShape(b.shape) = S + C + [n, r]
  b_extra_ndims = tensor_shape.TensorShape(b.shape).ndims - tensor_shape.TensorShape(a.shape).ndims

  # b_extra_sh = S, b_main_sh = C + [n, r]
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # No reason to flip unless the extra dims of b are big enough.  Why?
  # Assume adjoint/transpose = False.  Then...
  # By not flipping, we have to replicate a to shape
  #   b_extra_sh + tensor_shape.TensorShape(a.shape),
  # which could use extra memory.  But in all cases, the final output has shape
  #   b_extra_sh + tensor_shape.TensorShape(a.shape)[:-1] + [tensor_shape.TensorShape(b.shape)[-1]]
  # So we only end up creating a larger object if the end dim of b is smaller
  # than the end dim of a.  This often happens, e.g. if b was a vector that was
  # expanded to a matrix (by appending a singleton).

  # Since adjoint/transpose may not be False, we must make adjustments here.
  # The dim of b that holds the multiple equations.
  a_domain_sz_ = tensor_shape.TensorShape(a.shape)[-2 if adjoint_a or transpose_a else -1]
  b_eq_sz_ = tensor_shape.TensorShape(b.shape)[-2 if adjoint_b or transpose_b else -1]
  b_extra_sz_ = (
      np.prod(tensor_shape.TensorShape(b.shape)[:b_extra_ndims].as_list())
      if tensor_shape.TensorShape(b.shape)[:b_extra_ndims].is_fully_defined() else None)
  if (a_domain_sz_ is not None and b_eq_sz_ is not None and
      b_extra_sz_ is not None):
    if b_extra_sz_ < 2 or a_domain_sz_ <= b_eq_sz_:
      return a, b, identity, still_need_to_transpose

  # At this point, we're flipping for sure!
  # Any transposes/adjoints will happen here explicitly, rather than in calling
  # code.  Why?  To avoid having to write separate complex code for each case.
  if adjoint_a:
    a = _linalg.matrix_transpose(a, conjugate=True)
  elif transpose_a:
    a = _linalg.matrix_transpose(a, conjugate=False)
  if adjoint_b:
    b = _linalg.matrix_transpose(b, conjugate=True)
  elif transpose_b:
    b = _linalg.matrix_transpose(b, conjugate=False)
  still_need_to_transpose = False

  # Recompute shapes, since the transpose/adjoint may have changed them.
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # Permutation to put the extra dims at the end.
  perm = (
      np.concatenate(
          (np.arange(b_extra_ndims, tensor_shape.TensorShape(b.shape).ndims),
           np.arange(0, b_extra_ndims)), 0))
  b_extra_on_end = array_ops.transpose(b, perm=perm)

  # Now squash this end into one long dim.
  b_squashed_end = array_ops.reshape(
      b_extra_on_end, array_ops.concat((b_main_sh[:-1], [-1]), 0))

  def reshape_inv(y):
    # Expand the extra dims hanging off the end, "b_extra_sh".
    # Note we use y_sh[:-1] + [b_main_sh[-1]] rather than b_main_sh, because y
    # could have different batch dims than a and b, due to broadcasting.
    y_extra_shape = array_ops.concat(
        (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
    y_extra_on_end = array_ops.reshape(y, y_extra_shape)
    inverse_perm = np.argsort(perm)
    return array_ops.transpose(y_extra_on_end, perm=inverse_perm)

  return a, b_squashed_end, reshape_inv, still_need_to_transpose
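A minimal NumPy sketch (assumed shapes, no transposes/adjoints) of the flip performed above: fold b's extra leading dims into its trailing dim, do one matmul, then undo the fold with reshape_inv. The result matches plain broadcasting of a over b's extra dims:

import numpy as np

rng = np.random.default_rng(0)
m, n, r, s = 2, 3, 4, 5
a = rng.standard_normal((m, n))        # C = [] here, so a.shape = [m, n]
b = rng.standard_normal((s, n, r))     # b.shape = S + [n, r] with S = [s]

b_extra_ndims = b.ndim - a.ndim        # = 1
perm = np.concatenate((np.arange(b_extra_ndims, b.ndim),
                       np.arange(0, b_extra_ndims)))
b_extra_on_end = np.transpose(b, perm)        # [n, r, s]
b_squashed = b_extra_on_end.reshape(n, -1)    # [n, r * s]

y = a @ b_squashed                            # [m, r * s]
y_extra_on_end = y.reshape(m, r, s)           # expand the folded dims
result = np.transpose(y_extra_on_end, np.argsort(perm))   # [s, m, r]

np.testing.assert_allclose(result, a @ b)     # same as broadcasting a over S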
Example #12
    def _solve_matmul_internal(self,
                               x,
                               solve_matmul_fn,
                               adjoint=False,
                               adjoint_arg=False):
        # We heavily rely on Roth's column lemma [1]:
        # (A x B) * vec X = vec BXA^T,
        # where vec stacks all the columns of the matrix under each other.
        # In our case, we use a variant of the lemma that is row-major
        # friendly: (A x B) * vec' X = vec' AXB^T,
        # where vec' reshapes a matrix into a vector in row-major order. We can
        # repeatedly apply this for a collection of kronecker products.
        # Given that (A x B)^-1 = A^-1 x B^-1 and (A x B)^T = A^T x B^T, we can
        # use the above to compute multiplications and solves with any
        # composition of transposes. (A standalone check of the row-major
        # identity appears after this example.)
        output = x

        if adjoint_arg:
            if np.issubdtype(self.dtype, np.complexfloating):
                output = math_ops.conj(output)
        else:
            output = linalg.transpose(output)

        for o in reversed(self.operators):
            # Statically compute the reshape.
            if adjoint:
                operator_dimension = o.range_dimension_tensor()
            else:
                operator_dimension = o.domain_dimension_tensor()
            output_shape = _prefer_static_shape(output)

            if (ops.get_static_value(operator_dimension) is not None
                    and tensor_shape.TensorShape(output.shape)[-2] is not None
                    and tensor_shape.TensorShape(output.shape)[-1] is not None):
                # Everything needed is statically known; compute `dim` in Python.
                operator_dimension = ops.get_static_value(operator_dimension)
                dim = int(
                    tensor_shape.TensorShape(output.shape)[-2] *
                    tensor_shape.TensorShape(output.shape)[-1] //
                    operator_dimension)
            else:
                # Fall back to a dynamic computation so that `dim` is always
                # defined before it is used below.
                dim = _ops.cast(output_shape[-2] * output_shape[-1] //
                                operator_dimension,
                                dtype=dtypes.int32)

            output_shape = _prefer_static_concat_shape(
                output_shape[:-2], [dim, operator_dimension])
            output = array_ops.reshape(output, shape=output_shape)

            # Conjugate because we are trying to compute A @ B^T, but
            # `LinearOperator` only supports `adjoint_arg`.
            if np.issubdtype(self.dtype, np.complexfloating):
                output = math_ops.conj(output)

            output = solve_matmul_fn(o,
                                     output,
                                     adjoint=adjoint,
                                     adjoint_arg=True)

        if adjoint_arg:
            col_dim = _prefer_static_shape(x)[-2]
        else:
            col_dim = _prefer_static_shape(x)[-1]

        if adjoint:
            row_dim = self.domain_dimension_tensor()
        else:
            row_dim = self.range_dimension_tensor()

        matrix_shape = [row_dim, col_dim]

        output = array_ops.reshape(
            output,
            _prefer_static_concat_shape(
                _prefer_static_shape(output)[:-2], matrix_shape))

        if tensor_shape.TensorShape(x.shape).is_fully_defined():
            if adjoint_arg:
                column_dim = tensor_shape.TensorShape(x.shape)[-2]
            else:
                column_dim = tensor_shape.TensorShape(x.shape)[-1]
            broadcast_batch_shape = common_shapes.broadcast_shape(
                tensor_shape.TensorShape(x.shape)[:-2], self.batch_shape)
            if adjoint:
                matrix_dimensions = [self.domain_dimension, column_dim]
            else:
                matrix_dimensions = [self.range_dimension, column_dim]

            tensorshape_util.set_shape(
                output, broadcast_batch_shape.concatenate(matrix_dimensions))

        return output
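A minimal NumPy check (not the library's API) of the row-major variant of Roth's column lemma that the comment at the top of this example relies on: (A x B) * vec' X = vec' AXB^T, where vec' flattens a matrix in row-major order:

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3))
B = rng.standard_normal((4, 5))
X = rng.standard_normal((3, 5))   # shaped so that A @ X @ B.T is defined

lhs = np.kron(A, B) @ X.reshape(-1)   # row-major vec'
rhs = (A @ X @ B.T).reshape(-1)
np.testing.assert_allclose(lhs, rhs)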