Example No. 1
def _unvec_by(y, num_col):
  """Unstack vector to form a matrix, with a specified amount of columns."""
  return _linalg.matrix_transpose(
      array_ops.reshape(
          y,
          array_ops.concat(
              [array_ops.shape(y)[:-1], [num_col, -1]], axis=0)))
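
A minimal NumPy sketch of the same reshape-then-transpose trick, assuming a plain 1-D input and num_col = 3 (the values below are illustrative only):

import numpy as np

y = np.arange(6.)                                  # a length-6 vector
# Reshape to [num_col, -1], then swap the last two axes, as _unvec_by does.
unvec = np.reshape(y, (3, -1)).swapaxes(-1, -2)    # shape [2, 3]
print(unvec)
# [[0. 2. 4.]
#  [1. 3. 5.]]
# Column j of the result is the j-th consecutive chunk of y.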
Example No. 2
    def _vectorize_then_blockify(self, matrix):
        """Shape batch matrix to batch vector, then blockify trailing dimensions."""
        # Suppose
        #   _ops.TensorShape(matrix.shape) = [m0, m1, m2, m3],
        # and matrix is a matrix because the final two dimensions are matrix dims.
        #   self.block_depth = 2,
        #   self.block_shape = [b0, b1]  (note b0 * b1 = m2).
        # We will reshape matrix to
        #   [m3, m0, m1, b0, b1].

        # Vectorize: Reshape to batch vector.
        #   [m0, m1, m2, m3] --> [m3, m0, m1, m2]
        # This is called "vectorize" because we have taken the final two matrix dims
        # and turned this into a size m3 batch of vectors.
        vec = distribution_util.rotate_transpose(matrix, shift=1)

        # Blockify: reshape the trailing dimension into the block shape.
        #   [m3, m0, m1, m2] --> [m3, m0, m1, b0, b1]
        if (_ops.TensorShape(vec.shape).is_fully_defined()
                and self.block_shape.is_fully_defined()):
            # vec_leading_shape = [m3, m0, m1],
            # the parts of vec that will not be blockified.
            vec_leading_shape = _ops.TensorShape(vec.shape)[:-1]
            final_shape = vec_leading_shape.concatenate(self.block_shape)
        else:
            vec_leading_shape = array_ops.shape(vec)[:-1]
            final_shape = array_ops.concat(
                (vec_leading_shape, self.block_shape_tensor()), 0)
        return array_ops.reshape(vec, final_shape)
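
A rough NumPy sketch of the same shape flow, with made-up sizes m0..m3 = 4, 5, 6, 7 and block_shape = [2, 3] (so b0 * b1 = m2); rotate_transpose(..., shift=1) is imitated here with a rolled transpose:

import numpy as np

matrix = np.random.rand(4, 5, 6, 7)                             # [m0, m1, m2, m3]

# Vectorize: rotate the last dim to the front, [m0, m1, m2, m3] -> [m3, m0, m1, m2].
vec = np.transpose(matrix, np.roll(np.arange(matrix.ndim), 1))

# Blockify: split the trailing dim into the block shape, -> [m3, m0, m1, b0, b1].
blocked = vec.reshape(vec.shape[:-1] + (2, 3))
print(blocked.shape)                                            # (7, 4, 5, 2, 3)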
Example No. 3
    def _broadcast_batch_dims(self, x, spectrum):
        """Broadcast batch dims of batch matrix `x` and spectrum."""
        # _ops.TensorShape(spectrum.shape) = batch_shape + block_shape
        # First make spectrum a batch matrix with
        #   _ops.TensorShape(spectrum.shape) = batch_shape + [prod(block_shape), 1]
        spec_mat = array_ops.reshape(
            spectrum,
            array_ops.concat((self.batch_shape_tensor(), [-1, 1]), axis=0))
        # Second, broadcast the batch dims; this may require adding an array of zeros.
        x, spec_mat = linear_operator_util.broadcast_matrix_batch_dims(
            (x, spec_mat))
        # Third, put the block shape back into spectrum.
        batch_shape = array_ops.shape(x)[:-2]
        spectrum = array_ops.reshape(
            spec_mat,
            array_ops.concat((batch_shape, self.block_shape_tensor()), axis=0))

        return x, spectrum
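
A rough NumPy sketch of the same three steps, assuming hypothetical shapes batch_shape = [2], block_shape = [3, 4], and an x whose batch dims broadcast against them:

import numpy as np

spectrum = np.random.rand(2, 3, 4)                   # batch_shape + block_shape
x = np.random.rand(5, 2, 12, 6)                      # a [5, 2] batch of [12, 6] matrices

# First, make spectrum a batch matrix: batch_shape + [prod(block_shape), 1].
spec_mat = spectrum.reshape(2, -1, 1)                # [2, 12, 1]

# Second, broadcast the batch dims of x and spec_mat against each other.
batch_shape = np.broadcast_shapes(x.shape[:-2], spec_mat.shape[:-2])      # (5, 2)
spec_mat = np.broadcast_to(spec_mat, batch_shape + spec_mat.shape[-2:])

# Third, put the block shape back into spectrum.
spectrum = spec_mat.reshape(batch_shape + (3, 4))    # [5, 2, 3, 4]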
Example No. 4
 def reshape_inv(y):
     # Expand the extra dims hanging off the end, "b_extra_sh".
     # Note we use y_sh[:-1] + [b_main_sh[-1]] rather than b_main_sh, because y
     # could have different batch dims than a and b, because of broadcasting.
     y_extra_shape = array_ops.concat(
         (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
     y_extra_on_end = array_ops.reshape(y, y_extra_shape)
     inverse_perm = np.argsort(perm)
     return array_ops.transpose(y_extra_on_end, perm=inverse_perm)
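
The key trick here is that np.argsort(perm) is the inverse permutation, so transposing by perm and then by its argsort is a no-op; a small NumPy check with an illustrative perm:

import numpy as np

x = np.random.rand(2, 3, 4, 5)
perm = [1, 3, 0, 2]
inverse_perm = np.argsort(perm)                      # [2, 0, 3, 1]
np.testing.assert_array_equal(
    np.transpose(np.transpose(x, perm), inverse_perm), x)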
Example No. 5
 def _diag_part(self):
   diag_part = self.operators[0].diag_part()
   for operator in self.operators[1:]:
     # The diagonal of a Kronecker product is the flattened outer product of the
     # factors' diagonals: [B, D1, 1] * [B, 1, D2] --> [B, D1, D2].
     diag_part = diag_part[..., :, array_ops.newaxis]
     op_diag_part = operator.diag_part()[..., array_ops.newaxis, :]
     diag_part *= op_diag_part
     # Merge the last two dimensions: [B, D1, D2] --> [B, D1 * D2].
     diag_part = array_ops.reshape(
         diag_part,
         shape=array_ops.concat(
             [array_ops.shape(diag_part)[:-2], [-1]], axis=0))
   # For a non-square operator, the diagonal has length min(range, domain).
   if self.range_dimension > self.domain_dimension:
     diag_dimension = self.domain_dimension
   else:
     diag_dimension = self.range_dimension
   diag_part.set_shape(
       self.batch_shape.concatenate(diag_dimension))
   return diag_part
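
What the loop computes is the standard identity diag(A ⊗ B) = diag(A) ⊗ diag(B); a small NumPy check for two square factors with illustrative sizes:

import numpy as np

A = np.random.rand(3, 3)
B = np.random.rand(4, 4)

# Outer product of the diagonals, flattened row-major.
diag_part = (np.diag(A)[:, None] * np.diag(B)[None, :]).reshape(-1)

np.testing.assert_allclose(diag_part, np.diag(np.kron(A, B)))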
Example No. 6
 def _to_dense(self):
   product = self.operators[0].to_dense()
   for operator in self.operators[1:]:
      # Product has shape [B, R1, 1, C1, 1].
     product = product[
         ..., :, array_ops.newaxis, :, array_ops.newaxis]
     # Operator has shape [B, 1, R2, 1, C2].
     op_to_mul = operator.to_dense()[
         ..., array_ops.newaxis, :, array_ops.newaxis, :]
     # This is now [B, R1, R2, C1, C2].
     product *= op_to_mul
     # Now merge together dimensions to get [B, R1 * R2, C1 * C2].
     product = array_ops.reshape(
         product,
         shape=array_ops.concat(
             [array_ops.shape(product)[:-4],
              [array_ops.shape(product)[-4] * array_ops.shape(product)[-3],
               array_ops.shape(product)[-2] * array_ops.shape(product)[-1]]
             ], axis=0))
   product.set_shape(_ops.TensorShape(self.shape))
   return product
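
The interleaved newaxis pattern is a dense Kronecker product; a minimal NumPy check with no batch dims and hypothetical sizes R1, C1 = 2, 3 and R2, C2 = 4, 5:

import numpy as np

A = np.random.rand(2, 3)                             # [R1, C1]
B = np.random.rand(4, 5)                             # [R2, C2]

# Elementwise multiply of [R1, 1, C1, 1] and [1, R2, 1, C2] gives [R1, R2, C1, C2];
# merging (R1, R2) and (C1, C2) reproduces the Kronecker product.
product = (A[:, None, :, None] * B[None, :, None, :]).reshape(2 * 4, 3 * 5)

np.testing.assert_allclose(product, np.kron(A, B))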
Example No. 7
    def _unblockify_then_matricize(self, vec):
        """Flatten the block dimensions then reshape to a batch matrix."""
        # Suppose
        #   _ops.TensorShape(vec.shape) = [v0, v1, v2, v3],
        #   self.block_depth = 2.
        # Then
        #   leading shape = [v0, v1]
        #   block shape = [v2, v3].
        # We will reshape vec to
        #   [v1, v2*v3, v0].

        # Un-blockify: Flatten block dimensions.  Reshape
        #   [v0, v1, v2, v3] --> [v0, v1, v2*v3].
        if _ops.TensorShape(vec.shape).is_fully_defined():
            # vec_shape = [v0, v1, v2, v3]
            vec_shape = _ops.TensorShape(vec.shape).as_list()
            # vec_leading_shape = [v0, v1]
            vec_leading_shape = vec_shape[:-self.block_depth]
            # vec_block_shape = [v2, v3]
            vec_block_shape = vec_shape[-self.block_depth:]
            # flat_shape = [v0, v1, v2*v3]
            flat_shape = vec_leading_shape + [np.prod(vec_block_shape)]
        else:
            vec_shape = array_ops.shape(vec)
            vec_leading_shape = vec_shape[:-self.block_depth]
            vec_block_shape = vec_shape[-self.block_depth:]
            flat_shape = array_ops.concat(
                (vec_leading_shape, [math_ops.reduce_prod(vec_block_shape)]),
                0)
        vec_flat = array_ops.reshape(vec, flat_shape)

        # Matricize:  Reshape to batch matrix.
        #   [v0, v1, v2*v3] --> [v1, v2*v3, v0],
        # representing a shape [v1] batch of [v2*v3, v0] matrices.
        matrix = distribution_util.rotate_transpose(vec_flat, shift=-1)
        return matrix
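
A rough NumPy sketch of the same shape flow, with made-up sizes v0..v3 = 4, 5, 2, 3 and block_depth = 2; rotate_transpose(..., shift=-1) is imitated here with np.moveaxis:

import numpy as np

vec = np.random.rand(4, 5, 2, 3)                     # [v0, v1, v2, v3]

# Un-blockify: flatten the block dims, [v0, v1, v2, v3] -> [v0, v1, v2*v3].
flat = vec.reshape(vec.shape[:-2] + (-1,))

# Matricize: rotate the leading dim to the end, -> [v1, v2*v3, v0].
matrix = np.moveaxis(flat, 0, -1)
print(matrix.shape)                                  # (5, 6, 4)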
Example No. 8
def _reshape_for_efficiency(a,
                            b,
                            transpose_a=False,
                            transpose_b=False,
                            adjoint_a=False,
                            adjoint_b=False):
    """Maybe reshape a, b, and return an inverse map.  For matmul/solve."""
    def identity(x):
        return x

    # At this point, we have not taken transpose/adjoint of a/b.
    still_need_to_transpose = True

    if _ops.TensorShape(a.shape).ndims is None or _ops.TensorShape(
            b.shape).ndims is None:
        return a, b, identity, still_need_to_transpose

    # This could be handled in the future, but seems less common.
    if _ops.TensorShape(a.shape).ndims >= _ops.TensorShape(b.shape).ndims:
        return a, b, identity, still_need_to_transpose

    # From now on, we might modify b, but will not modify a.

    # Suppose:
    #   _ops.TensorShape(a.shape) =     C + [m, n],
    #   _ops.TensorShape(b.shape) = S + C + [n, r].
    b_extra_ndims = _ops.TensorShape(b.shape).ndims - _ops.TensorShape(
        a.shape).ndims

    # b_extra_sh = S, b_main_sh = C + [n, r]
    b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
    b_main_sh = array_ops.shape(b)[b_extra_ndims:]

    # No reason to flip unless the extra dims of b are big enough.  Why?
    # Assume adjoint/transpose = False.  Then...
    # By not flipping, we have to replicate a to shape
    #   b_extra_sh + _ops.TensorShape(a.shape),
    # which could use extra memory.  But in all cases, the final output has shape
    #   b_extra_sh + _ops.TensorShape(a.shape)[:-1] + [_ops.TensorShape(b.shape)[-1]]
    # So we only end up creating a larger object if the end dim of b is smaller
    # than the end dim of a.  This often happens, e.g. if b was a vector that was
    # expanded to a matrix (by appending a singleton).

    # Since adjoint/transpose may not be False, we must make adjustments here.
    # a_domain_sz_ is the size of a's contracted (domain) dim; b_eq_sz_ is the
    # size of the dim of b that holds the multiple equations.
    a_domain_sz_ = _ops.TensorShape(
        a.shape)[-2 if adjoint_a or transpose_a else -1]
    b_eq_sz_ = _ops.TensorShape(
        b.shape)[-2 if adjoint_b or transpose_b else -1]
    b_extra_sz_ = (
        np.prod(_ops.TensorShape(b.shape)[:b_extra_ndims].as_list()) if
        _ops.TensorShape(b.shape)[:b_extra_ndims].is_fully_defined() else None)
    if (a_domain_sz_ is not None and b_eq_sz_ is not None
            and b_extra_sz_ is not None):
        if b_extra_sz_ < 2 or a_domain_sz_ <= b_eq_sz_:
            return a, b, identity, still_need_to_transpose

    # At this point, we're flipping for sure!
    # Any transposes/adjoints will happen here explicitly, rather than in calling
    # code.  Why?  To avoid having to write separate complex code for each case.
    if adjoint_a:
        a = linalg.adjoint(a)
    elif transpose_a:
        a = linalg.transpose(a)
    if adjoint_b:
        b = linalg.adjoint(b)
    elif transpose_b:
        b = linalg.transpose(b)
    still_need_to_transpose = False

    # Recompute shapes, since the transpose/adjoint may have changed them.
    b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
    b_main_sh = array_ops.shape(b)[b_extra_ndims:]

    # Permutation to put the extra dims at the end.
    perm = (np.concatenate(
        (np.arange(b_extra_ndims,
                   _ops.TensorShape(b.shape).ndims), np.arange(
                       0, b_extra_ndims)), 0))
    b_extra_on_end = array_ops.transpose(b, perm=perm)

    # Now squash this end into one long dim.
    b_squashed_end = array_ops.reshape(
        b_extra_on_end, array_ops.concat((b_main_sh[:-1], [-1]), 0))

    def reshape_inv(y):
        # Expand the extra dims hanging off the end, "b_extra_sh".
        # Note we use y_sh[:-1] + [b_main_sh[-1]] rather than b_main_sh, because y
        # could have different batch dims than a and b, because of broadcasting.
        y_extra_shape = array_ops.concat(
            (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
        y_extra_on_end = array_ops.reshape(y, y_extra_shape)
        inverse_perm = np.argsort(perm)
        return array_ops.transpose(y_extra_on_end, perm=inverse_perm)

    return a, b_squashed_end, reshape_inv, still_need_to_transpose
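
A minimal NumPy sketch of the flip, assuming transposes/adjoints are all False and illustrative concrete shapes C = [2], S = [3, 4], m, n, r = 5, 6, 7; it checks that squashing S into b's last dim, multiplying, and undoing the reshape matches the ordinary broadcast matmul:

import numpy as np

a = np.random.rand(2, 5, 6)                          # C + [m, n]
b = np.random.rand(3, 4, 2, 6, 7)                    # S + C + [n, r]

# Move the extra dims S to the end of b, then fold them into the last dim.
b_end = np.transpose(b, (2, 3, 4, 0, 1))             # C + [n, r] + S
b_squashed = b_end.reshape(2, 6, -1)                 # C + [n, r * prod(S)]

y = a @ b_squashed                                   # C + [m, r * prod(S)]

# Undo: re-expand S on the end, then move it back to the front.
y_expanded = y.reshape(2, 5, 7, 3, 4)                # C + [m, r] + S
result = np.transpose(y_expanded, (3, 4, 0, 1, 2))   # S + C + [m, r]

np.testing.assert_allclose(result, a @ b)            # same as broadcasting a over S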
Example No. 9
def _vec(x):
  """Stacks column of matrix to form a single column."""
  return array_ops.reshape(
      _linalg.matrix_transpose(x),
      array_ops.concat(
          [array_ops.shape(x)[:-2], [-1]], axis=0))
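
In NumPy terms this is column-major (Fortran-order) flattening of the trailing matrix dims; a small check that also round-trips through the _unvec_by pattern from Example No. 1:

import numpy as np

x = np.arange(6.).reshape(2, 3)

# Stack the columns: transpose, then flatten row-major == flatten in Fortran order.
vec = np.swapaxes(x, -1, -2).reshape(-1)
np.testing.assert_array_equal(vec, x.flatten(order='F'))

# Unstacking with num_col = 3 columns (the _unvec_by pattern) recovers x.
np.testing.assert_array_equal(np.reshape(vec, (3, -1)).swapaxes(-1, -2), x)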