Beispiel #1
0
def eigh_abstract_eval(operand, lower):
  """Abstract evaluation rule for the symmetric eigendecomposition primitive.

  Args:
    operand: abstract input value. When it is a ``ShapedArray`` it must have
      shape ``[..., n, n]``; any other value is passed through unchanged.
    lower: accepted to match the primitive's parameters; it is not read by
      this rule (output shapes do not depend on it).

  Returns:
    A pair ``(v, w)``. For a ``ShapedArray`` operand, ``v`` has shape
    ``[..., n, n]`` with the operand's dtype and ``w`` has shape ``[..., n]``
    with dtype ``lax_internal._complex_basetype(operand.dtype)`` (presumably
    eigenvectors and eigenvalues respectively — confirm against the
    primitive's lowering). For any other operand, both entries are
    ``operand`` itself.

  Raises:
    ValueError: if a ``ShapedArray`` operand has fewer than two dimensions or
      its last two dimensions differ.
  """
  if not isinstance(operand, ShapedArray):
    # Non-ShapedArray abstract values are forwarded as-is for both outputs.
    return operand, operand

  if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
    # Previously the two literals were joined without a space
    # ("...n, n],got shape"); keep the separator explicit.
    raise ValueError(
        "Argument to symmetric eigendecomposition must have shape [..., n, n], "
        f"got shape {operand.shape}")

  batch_dims = operand.shape[:-2]
  n = operand.shape[-1]
  v = ShapedArray(batch_dims + (n, n), operand.dtype)
  w = ShapedArray(batch_dims + (n,),
                  lax_internal._complex_basetype(operand.dtype))
  return v, w
Beispiel #2
0
def _empty_svd(a, *, full_matrices, compute_uv):
  """Build SVD outputs for an input whose trailing matrix has a zero-sized side.

  The singular-value array is always created with length 0, so this helper
  presumably expects ``m == 0`` or ``n == 0`` for ``a`` of shape
  ``[..., m, n]`` — TODO confirm against callers.

  Args:
    a: array of shape ``[..., m, n]``.
    full_matrices: if true, the non-trivial factor is a ``max(m, n)``-sized
      identity broadcast over the batch dimensions; otherwise it is an empty
      ``[..., m, n]`` array.
    compute_uv: if false, only the singular values are returned.

  Returns:
    ``(s,)`` when ``compute_uv`` is false, otherwise ``(s, u, v)`` where one of
    ``u``/``v`` is a ``[..., 0, 0]`` array: ``u`` when ``m < n``, else ``v``.
  """
  lead = a.shape[:-2]
  rows, cols = a.shape[-2:]
  singular_values = jnp.empty(
      lead + (0,), dtype=lax_internal._complex_basetype(a.dtype))
  if not compute_uv:
    return (singular_values,)

  if full_matrices:
    dim = max(rows, cols)
    other_factor = jnp.broadcast_to(
        jnp.eye(dim, dtype=a.dtype), lead + (dim, dim))
  else:
    other_factor = jnp.empty(lead + (rows, cols), dtype=a.dtype)
  zero_factor = jnp.empty(lead + (0, 0), dtype=a.dtype)

  # The (0, 0) factor goes in the u slot when there are fewer rows than
  # columns, and in the v slot otherwise.
  if rows < cols:
    return singular_values, zero_factor, other_factor
  return singular_values, other_factor, zero_factor
Beispiel #3
0
def svd_abstract_eval(operand, full_matrices, compute_uv):
  """Abstract evaluation rule for the singular value decomposition primitive.

  Args:
    operand: abstract input; only ``ShapedArray`` values of shape
      ``[..., m, n]`` are supported.
    full_matrices: if true, ``u`` is ``[..., m, m]`` and ``vt`` is
      ``[..., n, n]``; otherwise both use ``k = min(m, n)`` for the shared side.
    compute_uv: if false, only the singular-value shape is produced.

  Returns:
    ``(s,)`` or ``(s, u, vt)`` as ``ShapedArray`` values, where ``s`` has shape
    ``[..., k]`` and dtype ``lax_internal._complex_basetype(operand.dtype)``.

  Raises:
    NotImplementedError: if ``operand`` is not a ``ShapedArray``.
    ValueError: if ``operand`` has fewer than two dimensions.
  """
  if not isinstance(operand, ShapedArray):
    raise NotImplementedError
  if operand.ndim < 2:
    raise ValueError("Argument to singular value decomposition must have ndims >= 2")

  batch = operand.shape[:-2]
  m, n = operand.shape[-2:]
  k = min(m, n)
  singular_values = ShapedArray(
      batch + (k,), lax_internal._complex_basetype(operand.dtype))
  if not compute_uv:
    return (singular_values,)

  # Full matrices keep u square in m and vt square in n; the reduced form
  # shrinks the shared side to k.
  u_cols, vt_rows = (m, n) if full_matrices else (k, k)
  u = ShapedArray(batch + (m, u_cols), operand.dtype)
  vt = ShapedArray(batch + (vt_rows, n), operand.dtype)
  return singular_values, u, vt