Example 1
def reshape_inv(y):
  # Inner closure from `_reshape_for_efficiency` (see Example 3): `b_main_sh`,
  # `b_extra_sh`, and `perm` are captured from that enclosing function.
  # Expand the extra dims hanging off the end, "b_extra_sh".
  # Note we use y_sh[:-1] + [b_main_sh[-1]] rather than b_main_sh, because y
  # could have different batch dims than a and b, because of broadcasting.
  y_extra_shape = array_ops.concat(
      (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
  y_extra_on_end = array_ops.reshape(y, y_extra_shape)
  inverse_perm = np.argsort(perm)
  return array_ops.transpose(y_extra_on_end, perm=inverse_perm)
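The `np.argsort(perm)` call is what makes the round trip work: it yields the permutation that undoes an earlier transpose by `perm`. A minimal NumPy-only sketch of that identity, with a made-up shape and permutation:

import numpy as np

x = np.arange(24).reshape(2, 3, 4)      # any tensor
perm = [2, 0, 1]                        # e.g. move the last axis to the front
y = np.transpose(x, perm)               # shape (4, 2, 3)

inverse_perm = np.argsort(perm)         # [1, 2, 0]
x_back = np.transpose(y, inverse_perm)  # undoes the first transpose
assert np.array_equal(x_back, x)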
Example 2
from tensorflow.python.ops import array_ops  # assumed TF-internal imports
from tensorflow.python.ops import math_ops


def _rotate_last_dim(x, rotate_right=False):
  """Rotate the last dimension either left or right."""
  ndims = array_ops.rank(x)
  if rotate_right:
    # Permutation [ndims-1, 0, ..., ndims-2]: the last axis moves to the front.
    transpose_perm = array_ops.concat(
        [[ndims - 1], math_ops.range(0, ndims - 1)], axis=0)
  else:
    # Permutation [1, ..., ndims-1, 0]: the first axis moves to the end.
    transpose_perm = array_ops.concat(
        [math_ops.range(1, ndims), [0]], axis=0)
  return array_ops.transpose(x, transpose_perm)
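`_rotate_last_dim` permutes the axis order of `x`: with `rotate_right=True` the last axis moves to the front, otherwise the first axis moves to the end. A NumPy-only sketch of the equivalent permutations, using a made-up shape:

import numpy as np

x = np.zeros((2, 3, 5))
ndims = x.ndim

perm_right = [ndims - 1] + list(range(ndims - 1))  # [2, 0, 1]
print(np.transpose(x, perm_right).shape)           # (5, 2, 3): last axis now first

perm_left = list(range(1, ndims)) + [0]            # [1, 2, 0]
print(np.transpose(x, perm_left).shape)            # (3, 5, 2): first axis now last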
Example 3
# Imports assumed by this snippet: NumPy plus TensorFlow's internal
# `tensor_shape` and `array_ops` modules.  `_linalg` only needs to provide a
# `matrix_transpose(x, conjugate=...)` function, so `array_ops` is used as a
# stand-in alias here.
import numpy as np

from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops

_linalg = array_ops  # stand-in; swap in the real linalg module if preferred


def _reshape_for_efficiency(a,
                            b,
                            transpose_a=False,
                            transpose_b=False,
                            adjoint_a=False,
                            adjoint_b=False):
  """Maybe reshape a, b, and return an inverse map.  For matmul/solve."""
  def identity(x):
    return x

  # At this point, we have not taken transpose/adjoint of a/b.
  still_need_to_transpose = True

  if tensor_shape.TensorShape(a.shape).ndims is None or tensor_shape.TensorShape(b.shape).ndims is None:
    return a, b, identity, still_need_to_transpose

  # This could be handled in the future, but seems less common.
  if tensor_shape.TensorShape(a.shape).ndims >= tensor_shape.TensorShape(b.shape).ndims:
    return a, b, identity, still_need_to_transpose

  # From now on, we might modify b, but will not modify a.

  # Suppose:
  #   tensor_shape.TensorShape(a.shape) =     C + [m, n]
  #   tensor_shape.TensorShape(b.shape) = S + C + [n, r]
  b_extra_ndims = tensor_shape.TensorShape(b.shape).ndims - tensor_shape.TensorShape(a.shape).ndims

  # b_extra_sh = S, b_main_sh = C + [n, r]
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # No reason to flip unless the extra dims of b are big enough.  Why?
  # Assume adjoint/transpose = False.  Then...
  # By not flipping, we have to replicate a to shape
  #   b_extra_sh + tensor_shape.TensorShape(a.shape),
  # which could use extra memory.  But in all cases, the final output has shape
  #   b_extra_sh + tensor_shape.TensorShape(a.shape)[:-1] + [tensor_shape.TensorShape(b.shape)[-1]]
  # So we only end up creating a larger object if the end dim of b is smaller
  # than the end dim of a.  This often happens, e.g. if b was a vector that was
  # expanded to a matrix (by appending a singleton).
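  # E.g. if a.shape = [2, 3, 3] and b.shape = [1000, 2, 3, 1] (a batch of
  # vectors expanded to matrices), not flipping would replicate a to the much
  # larger shape [1000, 2, 3, 3].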

  # Since adjoint/transpose may not be False, we must make adjustments here.
  # a_domain_sz_ is the size of a's domain dim; b_eq_sz_ is the size of the
  # dim of b that holds the multiple equations.
  a_domain_sz_ = tensor_shape.TensorShape(a.shape)[-2 if adjoint_a or transpose_a else -1]
  b_eq_sz_ = tensor_shape.TensorShape(b.shape)[-2 if adjoint_b or transpose_b else -1]
  b_extra_sz_ = (
      np.prod(tensor_shape.TensorShape(b.shape)[:b_extra_ndims].as_list())
      if tensor_shape.TensorShape(b.shape)[:b_extra_ndims].is_fully_defined() else None)
  if (a_domain_sz_ is not None and b_eq_sz_ is not None and
      b_extra_sz_ is not None):
    if b_extra_sz_ < 2 or a_domain_sz_ <= b_eq_sz_:
      return a, b, identity, still_need_to_transpose

  # At this point, we're flipping for sure!
  # Any transposes/adjoints will happen here explicitly, rather than in calling
  # code.  Why?  To avoid having to write separate complex code for each case.
  if adjoint_a:
    a = _linalg.matrix_transpose(a, conjugate=True)
  elif transpose_a:
    a = _linalg.matrix_transpose(a, conjugate=False)
  if adjoint_b:
    b = _linalg.matrix_transpose(b, conjugate=True)
  elif transpose_b:
    b = _linalg.matrix_transpose(b, conjugate=False)
  still_need_to_transpose = False

  # Recompute shapes, since the transpose/adjoint may have changed them.
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # Permutation to put the extra dims at the end.
  perm = (
      np.concatenate(
          (np.arange(b_extra_ndims, tensor_shape.TensorShape(b.shape).ndims),
           np.arange(0, b_extra_ndims)), 0))
  b_extra_on_end = array_ops.transpose(b, perm=perm)

  # Now squash this end into one long dim.
  b_squashed_end = array_ops.reshape(
      b_extra_on_end, array_ops.concat((b_main_sh[:-1], [-1]), 0))
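  # If b began as S + C + [n, r], b_squashed_end now has shape
  # C + [n, r * prod(S)], so it can be matmul'd/solved against a
  # (shape C + [m, n]) without replicating a across S.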

  def reshape_inv(y):
    # Expand the extra dims hanging off the end, "b_extra_sh".
    # Note we use y_sh[:-1] + [b_main_sh[-1]] rather than b_main_sh, because y
    # could have different batch dims than a and b, because of broadcasting.
    y_extra_shape = array_ops.concat(
        (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
    y_extra_on_end = array_ops.reshape(y, y_extra_shape)
    inverse_perm = np.argsort(perm)
    return array_ops.transpose(y_extra_on_end, perm=inverse_perm)

  return a, b_squashed_end, reshape_inv, still_need_to_transpose
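A minimal usage sketch, assuming eager TensorFlow and the stand-in imports at the top of this example; the shapes (S = [5], C = [2], m = n = 3, r = 1) are made up for illustration:

import tensorflow as tf

a = tf.random.normal([2, 3, 3])     # C + [m, n]
b = tf.random.normal([5, 2, 3, 1])  # S + C + [n, r]

a2, b2, reshape_inv, still_need_to_transpose = _reshape_for_efficiency(a, b)

# b was flipped: its extra leading dims S were folded into its last dim, so
# the matmul below broadcasts over C only, and a is never tiled across S.
print(b2.shape)        # (2, 3, 5)
y = tf.matmul(a2, b2)  # shape (2, 3, 5); no transposes were requested, so
                       # still_need_to_transpose is False here.

# reshape_inv restores the broadcast layout S + C + [m, r].
result = reshape_inv(y)
print(result.shape)    # (5, 2, 3, 1)

# Same values as broadcasting the matmul directly.
expected = tf.matmul(tf.broadcast_to(a, [5, 2, 3, 3]), b)
print(tf.reduce_max(tf.abs(result - expected)).numpy())  # ~0.0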