Example #1
def _reshape_for_efficiency(a,
                            b,
                            transpose_a=False,
                            transpose_b=False,
                            adjoint_a=False,
                            adjoint_b=False):
    """Maybe reshape a, b, and return an inverse map.  For matmul/solve."""
    def identity(x):
        return x

    # At this point, we have not taken transpose/adjoint of a/b.
    still_need_to_transpose = True

    if a.shape.ndims is None or b.shape.ndims is None:
        return a, b, identity, still_need_to_transpose

    # This could be handled in the future, but seems less common.
    if a.shape.ndims >= b.shape.ndims:
        return a, b, identity, still_need_to_transpose

    # From now on, we might modify b, but will not modify a.

    # Suppose:
    #   a.shape =     C + [m, n]
    #   b.shape = S + C + [n, r]
    b_extra_ndims = b.shape.ndims - a.shape.ndims

    # b_extra_sh = S, b_main_sh = C + [n, r]
    b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
    b_main_sh = array_ops.shape(b)[b_extra_ndims:]

    # No reason to flip unless the extra dims of b are big enough.  Why?
    # Assume adjoint/transpose = False.  Then...
    # By not flipping, we have to replicate a to shape
    #   b_extra_sh + a.shape,
    # which could use extra memory.  But in all cases, the final output has shape
    #   b_extra_sh + a.shape[:-1] + [b.shape[-1]]
    # So we only end up creating a larger object if the end dim of b is smaller
    # than the end dim of a.  This often happens, e.g. if b was a vector that was
    # expanded to a matrix (by appending a singleton).
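    # Illustrative numbers (not from the source): with a.shape = [5, 3] and
    # b.shape = [1000, 3, 1] (so S = [1000], m = 5, n = 3, r = 1), replicating
    # a would materialize 1000 * 5 * 3 entries while the output only has
    # 1000 * 5 * 1, so reshaping b instead is the cheaper route.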

    # Since adjoint/transpose may not be False, we must make adjustments here.
    # b_eq_sz_ is the dim of b that holds the multiple equations.
    a_domain_sz_ = a.shape[-2 if adjoint_a or transpose_a else -1]
    b_eq_sz_ = b.shape[-2 if adjoint_b or transpose_b else -1]
    b_extra_sz_ = (np.prod(b.shape[:b_extra_ndims].as_list())
                   if b.shape[:b_extra_ndims].is_fully_defined() else None)
    if (a_domain_sz_ is not None and b_eq_sz_ is not None
            and b_extra_sz_ is not None):
        if b_extra_sz_ < 2 or a_domain_sz_ <= b_eq_sz_:
            return a, b, identity, still_need_to_transpose
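    # Concrete check (illustrative): with a_domain_sz_ = 3, b_eq_sz_ = 1 and
    # b_extra_sz_ = 1000, neither early-out above fires, so we flip below.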

    # At this point, we're flipping for sure!
    # Any transposes/adjoints will happen here explicitly, rather than in calling
    # code.  Why?  To avoid having to write separate complex code for each case.
    if adjoint_a:
        a = linalg.adjoint(a)
    elif transpose_a:
        a = linalg.transpose(a)
    if adjoint_b:
        b = linalg.adjoint(b)
    elif transpose_b:
        b = linalg.transpose(b)
    still_need_to_transpose = False

    # Recompute shapes, since the transpose/adjoint may have changed them.
    b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
    b_main_sh = array_ops.shape(b)[b_extra_ndims:]

    # Permutation to put the extra dims at the end.
    perm = (np.concatenate(
        (np.arange(b_extra_ndims, b.shape.ndims), np.arange(0, b_extra_ndims)),
        0))
    b_extra_on_end = array_ops.transpose(b, perm=perm)

    # Now squash this end into one long dim.
    b_squashed_end = array_ops.reshape(
        b_extra_on_end, array_ops.concat((b_main_sh[:-1], [-1]), 0))
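    # Shape walk-through (illustrative): if b.shape = S + C + [n, r], then
    # b_extra_on_end.shape = C + [n, r] + S and
    # b_squashed_end.shape = C + [n, r * prod(S)].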

    def reshape_inv(y):
        # Expand the extra dims hanging off the end, "b_extra_sh".
        # Note we use y_sh[:-1] + [b_main_sh[-1]] rather than b_main_sh, because
        # y could have different batch dims than a and b, due to broadcasting.
        y_extra_shape = array_ops.concat(
            (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
        y_extra_on_end = array_ops.reshape(y, y_extra_shape)
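        # argsort of a permutation array yields its inverse permutation.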
        inverse_perm = np.argsort(perm)
        return array_ops.transpose(y_extra_on_end, perm=inverse_perm)

    return a, b_squashed_end, reshape_inv, still_need_to_transpose
    def _solve_matmul_internal(self,
                               x,
                               solve_matmul_fn,
                               adjoint=False,
                               adjoint_arg=False):
        # We heavily rely on Roth's column lemma [1]:
        #   (A x B) * vec X = vec BXA^T,
        # where vec stacks all the columns of the matrix under each other.
        # In our case, we use a variant of the lemma that is row-major
        # friendly:
        #   (A x B) * vec' X = vec' AXB^T,
        # where vec' reshapes a matrix into a vector row by row. We can apply
        # this repeatedly for a collection of Kronecker products.
        # Given that (A x B)^-1 = A^-1 x B^-1 and (A x B)^T = A^T x B^T, we can
        # use the above to compute multiplications and solves with any
        # composition of transposes.
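        # Worked shape example (illustrative, not from the source): with A of
        # shape [2, 2] and B of shape [3, 3], X has shape [2, 3], vec' X has 6
        # entries, and (A x B) @ vec' X == vec'(A @ X @ B^T). The loop below
        # applies this identity one factor at a time, in reverse order.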
        output = x

        if adjoint_arg:
            if self.dtype.is_complex:
                output = math_ops.conj(output)
        else:
            output = linalg.transpose(output)

        for o in reversed(self.operators):
            # Statically compute the reshape.
            if adjoint:
                operator_dimension = o.range_dimension_tensor()
            else:
                operator_dimension = o.domain_dimension_tensor()
            output_shape = _prefer_static_shape(output)

            if tensor_util.constant_value(operator_dimension) is not None:
                operator_dimension = tensor_util.constant_value(
                    operator_dimension)
                if output.shape[-2] is not None and output.shape[
                        -1] is not None:
                    dim = int(output.shape[-2] * output_shape[-1] //
                              operator_dimension)
            else:
                dim = math_ops.cast(output_shape[-2] * output_shape[-1] //
                                    operator_dimension,
                                    dtype=dtypes.int32)

            output_shape = _prefer_static_concat_shape(
                output_shape[:-2], [dim, operator_dimension])
            output = array_ops.reshape(output, shape=output_shape)
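            # After the reshape, `output` has shape batch + [dim,
            # operator_dimension], so the factor `o` below acts on that last
            # dimension via `adjoint_arg=True`.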

            # Conjugate because we are trying to compute A @ B^T, but
            # `LinearOperator` only supports `adjoint_arg`.
            if self.dtype.is_complex:
                output = math_ops.conj(output)

            output = solve_matmul_fn(o,
                                     output,
                                     adjoint=adjoint,
                                     adjoint_arg=True)

        if adjoint_arg:
            col_dim = _prefer_static_shape(x)[-2]
        else:
            col_dim = _prefer_static_shape(x)[-1]

        if adjoint:
            row_dim = self.domain_dimension_tensor()
        else:
            row_dim = self.range_dimension_tensor()

        matrix_shape = [row_dim, col_dim]

        output = array_ops.reshape(
            output,
            _prefer_static_concat_shape(
                _prefer_static_shape(output)[:-2], matrix_shape))

        if x.shape.is_fully_defined():
            if adjoint_arg:
                column_dim = x.shape[-2]
            else:
                column_dim = x.shape[-1]
            broadcast_batch_shape = common_shapes.broadcast_shape(
                x.shape[:-2], self.batch_shape)
            if adjoint:
                matrix_dimensions = [self.domain_dimension, column_dim]
            else:
                matrix_dimensions = [self.range_dimension, column_dim]

            output.set_shape(
                broadcast_batch_shape.concatenate(matrix_dimensions))

        return output
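# A minimal usage sketch (not part of the source above): `_solve_matmul_internal`
# reads like an internal helper of a Kronecker-product `LinearOperator`, so the
# hedged example below exercises the same kind of code path through the public
# `tf.linalg.LinearOperatorKronecker` API. All names and shapes here are
# illustrative and assume TensorFlow is installed.
import tensorflow as tf

a_factor = tf.constant([[1., 2.], [3., 4.]])
b_factor = tf.constant([[1., 0.], [0., 2.]])
kron = tf.linalg.LinearOperatorKronecker([
    tf.linalg.LinearOperatorFullMatrix(a_factor),
    tf.linalg.LinearOperatorFullMatrix(b_factor),
])
x = tf.random.normal([4, 3])   # 4 = 2 * 2 rows, matching the Kronecker product
y = kron.matmul(x)             # internally applies one factor at a time
print(y.shape)                 # (4, 3)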
def _reshape_for_efficiency(a,
                            b,
                            transpose_a=False,
                            transpose_b=False,
                            adjoint_a=False,
                            adjoint_b=False):
  """Maybe reshape a, b, and return an inverse map.  For matmul/solve."""
  def identity(x):
    return x

  # At this point, we have not taken transpose/adjoint of a/b.
  still_need_to_transpose = True

  if a.shape.ndims is None or b.shape.ndims is None:
    return a, b, identity, still_need_to_transpose

  # This could be handled in the future, but seems less common.
  if a.shape.ndims >= b.shape.ndims:
    return a, b, identity, still_need_to_transpose

  # From now on, we might modify b, but will not modify a.

  # Suppose:
  #   a.shape =     C + [m, n]
  #   b.shape = S + C + [n, r]
  b_extra_ndims = b.shape.ndims - a.shape.ndims

  # b_extra_sh = S, b_main_sh = C + [n, r]
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # No reason to flip unless the extra dims of b are big enough.  Why?
  # Assume adjoint/transpose = False.  Then...
  # By not flipping, we have to replicate a to shape
  #   b_extra_sh + a.shape,
  # which could use extra memory.  But in all cases, the final output has shape
  #   b_extra_sh + a.shape[:-1] + [b.shape[-1]]
  # So we only end up creating a larger object if the end dim of b is smaller
  # than the end dim of a.  This often happens, e.g. if b was a vector that was
  # expanded to a matrix (by appending a singleton).

  # Since adjoint/transpose may not be False, we must make adjustments here.
  # b_eq_sz_ is the dim of b that holds the multiple equations.
  a_domain_sz_ = a.shape[-2 if adjoint_a or transpose_a else -1]
  b_eq_sz_ = b.shape[-2 if adjoint_b or transpose_b else -1]
  b_extra_sz_ = (
      np.prod(b.shape[:b_extra_ndims].as_list())
      if b.shape[:b_extra_ndims].is_fully_defined() else None)
  if (a_domain_sz_ is not None and b_eq_sz_ is not None and
      b_extra_sz_ is not None):
    if b_extra_sz_ < 2 or a_domain_sz_ <= b_eq_sz_:
      return a, b, identity, still_need_to_transpose

  # At this point, we're flipping for sure!
  # Any transposes/adjoints will happen here explicitly, rather than in calling
  # code.  Why?  To avoid having to write separate complex code for each case.
  if adjoint_a:
    a = linalg.adjoint(a)
  elif transpose_a:
    a = linalg.transpose(a)
  if adjoint_b:
    b = linalg.adjoint(b)
  elif transpose_b:
    b = linalg.transpose(b)
  still_need_to_transpose = False

  # Recompute shapes, since the transpose/adjoint may have changed them.
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # Permutation to put the extra dims at the end.
  perm = (
      array_ops.concat(
          (math_ops.range(b_extra_ndims, b.shape.ndims),
           math_ops.range(0, b_extra_ndims)), 0))
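  # `perm` is a tensor here (a concat of ranges), so `reshape_inv` below can
  # invert it with `array_ops.invert_permutation` rather than `np.argsort`.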
  b_extra_on_end = array_ops.transpose(b, perm=perm)

  # Now squash this end into one long dim.
  b_squashed_end = array_ops.reshape(
      b_extra_on_end, array_ops.concat((b_main_sh[:-1], [-1]), 0))

  def reshape_inv(y):
    # Expand the extra dims hanging off the end, "b_extra_sh".
    # Note we use y_sh[:-1] + [b_main_sh[-1]] rather than b_main_sh, because
    # y could have different batch dims than a and b, due to broadcasting.
    y_extra_shape = array_ops.concat(
        (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
    y_extra_on_end = array_ops.reshape(y, y_extra_shape)
    return array_ops.transpose(
        y_extra_on_end, perm=array_ops.invert_permutation(perm))

  return a, b_squashed_end, reshape_inv, still_need_to_transpose
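# A hedged usage sketch (not part of the source): how the pieces returned by
# `_reshape_for_efficiency` fit together for a plain matmul, assuming TensorFlow
# is installed and the function above (with its internal imports) is in scope.
# The shapes below are illustrative only.
import tensorflow as tf

a = tf.random.normal([5, 3])        # a.shape = C + [m, n] with C = []
b = tf.random.normal([1000, 3, 1])  # b.shape = S + C + [n, r] with S = [1000]

a2, b2, reshape_inv, still_need_to_transpose = _reshape_for_efficiency(a, b)
# still_need_to_transpose is False here: the function took the reshape path and
# already applied any requested transposes/adjoints (none in this call).
assert not still_need_to_transpose

y = tf.linalg.matmul(a2, b2)  # small matmul on the squashed operands
y = reshape_inv(y)            # back to shape S + C + [m, r] = [1000, 5, 1]
print(y.shape)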