Пример #1
0
def _average(a,
             axis: Optional[Union[int, Tuple[int, ...]]] = None,
             weights=None,
             returned=False,
             keepdims=False):
    """Compute the (optionally weighted) average of ``a`` along ``axis``.

    Args:
      a: array-like input values.
      axis: ``None`` to average over every element, a single int, or a tuple
        of ints.
      weights: optional array-like weights. Either the same shape as ``a``,
        or 1-D with length ``a.shape[axis]`` when a single int ``axis`` is
        given.
      returned: if True, also return the sum of the weights, broadcast to the
        shape of the average.
      keepdims: if True, reduced axes are kept as size-1 dimensions.

    Returns:
      ``avg``, or ``(avg, weights_sum)`` when ``returned`` is True.

    Raises:
      ValueError: if ``a`` and ``weights`` have different shapes and ``axis``
        is None or a tuple, ``weights`` is not 1-D, or its length does not
        match ``a.shape[axis]``.
    """
    if weights is None:  # Treat all weights as 1
        _check_arraylike("average", a)
        a, = _promote_dtypes_inexact(a)
        avg = mean(a, axis=axis, keepdims=keepdims)
        if axis is None:
            weights_sum = lax.full((),
                                   core.dimension_as_value(a.size),
                                   dtype=avg.dtype)
        elif isinstance(axis, tuple):
            # With unit weights the weight sum is the number of reduced
            # elements: the product of the sizes of every reduced dimension.
            # (Indexing a.shape with the tuple itself would raise.)
            count = 1
            for ax in axis:
                count = count * a.shape[ax]
            weights_sum = lax.full_like(avg,
                                        core.dimension_as_value(count),
                                        dtype=avg.dtype)
        else:
            weights_sum = lax.full_like(avg,
                                        core.dimension_as_value(a.shape[axis]),
                                        dtype=avg.dtype)
    else:
        _check_arraylike("average", a, weights)
        a, weights = _promote_dtypes_inexact(a, weights)

        a_shape = np.shape(a)
        a_ndim = len(a_shape)
        weights_shape = np.shape(weights)
        if isinstance(axis, tuple):
            # _canonicalize_axis handles a single int; normalize each entry.
            axis = tuple(_canonicalize_axis(ax, a_ndim) for ax in axis)
        elif axis is not None:
            axis = _canonicalize_axis(axis, a_ndim)

        if a_shape != weights_shape:
            # Make sure the dimensions work out
            if axis is None:
                raise ValueError("Axis must be specified when shapes of a and "
                                 "weights differ.")
            if isinstance(axis, tuple):
                raise ValueError("Single axis expected when shapes of a and "
                                 "weights differ.")
            if len(weights_shape) != 1:
                raise ValueError("1D weights expected when shapes of a and "
                                 "weights differ.")
            if not core.symbolic_equal_dim(weights_shape[0], a_shape[axis]):
                raise ValueError("Length of weights not "
                                 "compatible with specified axis.")

            # Lay the 1-D weights along `axis` so they broadcast against `a`.
            weights = _broadcast_to(weights,
                                    (a_ndim - 1) * (1, ) + weights_shape)
            weights = _moveaxis(weights, -1, axis)

        weights_sum = sum(weights, axis=axis, keepdims=keepdims)
        avg = sum(a * weights, axis=axis, keepdims=keepdims) / weights_sum

    if returned:
        if avg.shape != weights_sum.shape:
            weights_sum = _broadcast_to(weights_sum, avg.shape)
        return avg, weights_sum
    return avg
Пример #2
0
def _resize(image, shape: core.Shape, method: Union[str, ResizeMethod],
            antialias: bool, precision):
    """Resize ``image`` to ``shape`` with the interpolation ``method``.

    ``method`` is a ``ResizeMethod`` or a string accepted by
    ``ResizeMethod.from_string``. Nearest-neighbour resizing is delegated to
    ``_resize_nearest``. Every other method goes through
    ``_scale_and_translate`` with zero translation, restricted to the
    dimensions whose size actually changes.

    Raises:
      ValueError: if ``len(shape)`` differs from ``image.ndim``.
    """
    if len(shape) != image.ndim:
        raise ValueError(
            'shape must have length equal to the number of dimensions of x; '
            f' {shape} vs {image.shape}')
    if isinstance(method, str):
        method = ResizeMethod.from_string(method)
    if method == ResizeMethod.NEAREST:
        return _resize_nearest(image, shape)
    assert isinstance(method, ResizeMethod)
    kernel = _kernels[method]

    image, = _promote_dtypes_inexact(image)
    # All remaining kernels are interpolating, so an identity warp
    # (scale 1, translation 0) reproduces the input: dimensions whose size is
    # unchanged can be skipped entirely.
    dims_to_resize = []
    scale_factors = []
    for dim, (in_size, out_size) in enumerate(zip(image.shape, shape)):
        if core.symbolic_equal_dim(in_size, out_size):
            continue
        dims_to_resize.append(dim)
        if core.symbolic_equal_dim(out_size, 0):
            scale_factors.append(1.0)
        else:
            scale_factors.append(core.dimension_as_value(out_size) /
                                 core.dimension_as_value(in_size))
    translations = [0.] * len(dims_to_resize)
    return _scale_and_translate(image, shape, tuple(dims_to_resize),
                                scale_factors, translations, kernel,
                                antialias, precision)
Пример #3
0
def matrix_power(a, n):
  """Raise the square matrix (or stack of matrices) ``a`` to the integer power ``n``.

  ``n == 0`` gives identity matrices broadcast to ``a.shape``. A negative
  ``n`` inverts ``a`` first and then uses ``|n|``. Powers 1-3 are written
  out; larger powers use binary exponentiation (repeated squaring).

  Raises:
    TypeError: if ``a`` has fewer than two dimensions, its last two
      dimensions differ, or ``n`` is not an integer.
  """
  a, = _promote_dtypes_inexact(jnp.asarray(a))

  if a.ndim < 2:
    raise TypeError(f"{a.ndim}-dimensional array given. Array must be at least "
                    "two-dimensional")
  rows, cols = a.shape[-2], a.shape[-1]
  if rows != cols:
    raise TypeError("Last 2 dimensions of the array must be square")
  try:
    n = operator.index(n)
  except TypeError as err:
    raise TypeError(f"exponent must be an integer, got {n}") from err

  if n == 0:
    return jnp.broadcast_to(jnp.eye(rows, dtype=a.dtype), a.shape)
  if n < 0:
    a = inv(a)
    n = np.abs(n)

  if n == 1:
    return a
  if n == 2:
    return a @ a
  if n == 3:
    return (a @ a) @ a

  # Walk the bits of n from least significant upward: `power` holds
  # a**(2**i), and it is folded into `acc` whenever bit i is set.
  acc = None
  power = a
  while True:
    n, bit = divmod(n, 2)
    if bit:
      acc = power if acc is None else (acc @ power)
    if n == 0:
      return acc
    power = power @ power
Пример #4
0
def polyval(p, x, *, unroll=16):
    """Evaluate the polynomial with coefficients ``p`` at the points ``x``.

    Coefficients run along axis 0 of ``p``, highest degree first. Evaluation
    uses Horner's rule expressed as a ``lax.scan`` over ``p``. ``unroll`` is
    forwarded to ``lax.scan``.
    """
    _check_arraylike("polyval", p, x)
    p, x = _promote_dtypes_inexact(p, x)
    out_shape = lax.broadcast_shapes(p.shape[1:], x.shape)
    init = lax.full_like(x, 0, shape=out_shape, dtype=x.dtype)

    def horner_step(acc, coeff):
        # acc <- acc * x + c_k, carrying no per-step output.
        return acc * x + coeff, None

    result, _ = lax.scan(horner_step, init, p, unroll=unroll)
    return result
Пример #5
0
def svd(a,
        full_matrices: bool = True,
        compute_uv: bool = True,
        hermitian: bool = False):
    """Singular value decomposition of ``a``.

    Without ``hermitian`` the call is forwarded to ``lax_linalg.svd``.

    With ``hermitian=True`` the result is derived from ``lax_linalg.eigh``.
    The singular values are the absolute eigenvalues, sorted in descending
    order. ``u`` holds the matching eigenvectors, and ``vh`` is the conjugate
    transpose of ``u`` with its columns scaled by the eigenvalue signs.
    ``full_matrices`` is not consulted on this path.
    """
    a, = _promote_dtypes_inexact(jnp.asarray(a))
    if not hermitian:
        return lax_linalg.svd(a,
                              full_matrices=full_matrices,
                              compute_uv=compute_uv)

    # lax_linalg.eigh yields (eigenvectors, eigenvalues).
    eigvecs, eigvals = lax_linalg.eigh(a)
    magnitudes = lax.abs(eigvals)
    last_axis = magnitudes.ndim - 1
    if not compute_uv:
        return lax.rev(lax.sort(magnitudes, dimension=-1),
                       dimensions=[last_axis])

    signs = lax.sign(eigvals)
    order = lax.broadcasted_iota(np.int64, magnitudes.shape,
                                 dimension=last_axis)
    # Sort ascending by magnitude (indices and signs ride along), then
    # reverse all three to obtain descending order.
    magnitudes, order, signs = lax.sort((magnitudes, order, signs),
                                        dimension=-1, num_keys=1)
    magnitudes, order, signs = [
        lax.rev(t, dimensions=[last_axis]) for t in (magnitudes, order, signs)
    ]
    u = jnp.take_along_axis(eigvecs, order[..., None, :], axis=-1)
    vh = _H(u * signs[..., None, :])
    return u, magnitudes, vh
Пример #6
0
def hypot(x1, x2):
  """Elementwise ``sqrt(x1**2 + x2**2)`` computed without undue overflow.

  The larger magnitude ``m`` is factored out as ``m * sqrt(1 + (s/m)**2)``,
  so no intermediate square can overflow. As in IEEE 754 and
  ``numpy.hypot``, the result is ``+inf`` wherever either input is
  infinite, even when the other input is NaN.
  """
  _check_arraylike("hypot", x1, x2)
  x1, x2 = _promote_dtypes_inexact(x1, x2)
  x1 = lax.abs(x1)
  x2 = lax.abs(x2)
  inf = float("inf")
  # Must be computed before the max/min step: the scaled formula gives
  # inf/inf = NaN when both inputs are infinite, and maximum() propagates NaN.
  either_inf = (x1 == inf) | (x2 == inf)
  big, small = maximum(x1, x2), minimum(x1, x2)
  is_zero = big == 0
  # Replace zero divisors with 1; those lanes are overridden below anyway.
  safe_big = lax.select(is_zero, lax_internal._ones(big), big)
  result = lax.select(is_zero, big,
                      big * lax.sqrt(1 + lax.square(lax.div(small, safe_big))))
  return lax.select(either_inf, lax.full_like(result, inf), result)
Пример #7
0
def _solve_triangular(a, b, trans, lower, unit_diagonal):
    """Solve ``op(a) @ x = b`` for ``x`` where ``a`` is triangular.

    ``trans`` selects ``op``: 0/"N" for ``a``, 1/"T" for its transpose, and
    2/"C" for its conjugate transpose. A ``b`` with one dimension fewer than
    ``a`` is treated as a vector right-hand side.

    Raises:
      ValueError: for any other ``trans`` value.
    """
    if trans in (0, "N"):
        transpose_a = conjugate_a = False
    elif trans in (1, "T"):
        transpose_a, conjugate_a = True, False
    elif trans in (2, "C"):
        transpose_a = conjugate_a = True
    else:
        raise ValueError(f"Invalid 'trans' value {trans}")

    a, b = _promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))

    # The solve is issued with a matrix right-hand side: a vector `b` gets a
    # trailing unit axis that is dropped again from the result.
    vector_rhs = jnp.ndim(a) == jnp.ndim(b) + 1
    rhs = b[..., None] if vector_rhs else b
    solution = lax_linalg.triangular_solve(
        a, rhs,
        left_side=True,
        lower=lower,
        transpose_a=transpose_a,
        conjugate_a=conjugate_a,
        unit_diagonal=unit_diagonal)
    return solution[..., 0] if vector_rhs else solution
Пример #8
0
def modf(x, out=None):
    """Split ``x`` elementwise into fractional and integral parts.

    Both parts keep the sign of ``x``. As in C ``modf`` and ``numpy.modf``,
    an infinite input yields a signed-zero fractional part and the infinity
    itself as the integral part. NaN propagates to both parts.

    Returns:
      ``(fractional, integral)``.

    Raises:
      NotImplementedError: if ``out`` is given.
    """
    _check_arraylike("modf", x)
    x, = _promote_dtypes_inexact(x)
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.modf is not supported.")
    # Truncate toward zero: floor for non-negative values, ceil otherwise.
    whole = _where(lax.ge(x, lax_internal._zero(x)), floor(x), ceil(x))
    # For +/-inf, x - whole would be inf - inf = NaN. Use sign(x) * 0 there
    # so the fractional part is +0.0 / -0.0.
    is_inf = lax.abs(x) == float("inf")
    signed_zero = lax.mul(lax.sign(x), lax_internal._zero(x))
    frac = _where(is_inf, signed_zero, x - whole)
    return frac, whole
Пример #9
0
def sinc(x):
    """Normalized sinc, ``sin(pi * x) / (pi * x)``, elementwise.

    At ``x == 0`` the value comes from ``_sinc_maclaurin(0, pi * x)``
    (presumably so derivatives at the origin stay well defined — TODO
    confirm). Elsewhere the quotient is evaluated with the zero lanes'
    denominator replaced by 1, so the discarded branch never divides by zero.
    """
    _check_arraylike("sinc", x)
    x, = _promote_dtypes_inexact(x)
    at_origin = lax.eq(x, _lax_const(x, 0))
    scaled = lax.mul(_lax_const(x, np.pi), x)
    denom = _where(at_origin, _lax_const(x, 1), scaled)
    ratio = lax.div(lax.sin(denom), denom)
    return _where(at_origin, _sinc_maclaurin(0, scaled), ratio)
Пример #10
0
def _roots_no_zeros(p):
    """Return the roots of ``p`` as the eigenvalues of its companion matrix.

    ``p`` holds coefficients from highest to lowest degree and is expected to
    have no leading zeros. A polynomial with fewer than two coefficients has
    no roots, so an empty complex array is returned. This matches the
    ``p.size < 2`` handling in ``roots``.
    """
    p, = _promote_dtypes_inexact(p)
    # Stripping leading zeros can leave fewer than 2 coefficients
    # (e.g. [0, 5] -> [5], or all zeros -> []); ones((p.size - 2,)) would
    # then be asked for a negative size.
    if p.size < 2:
        return array([], dtype=dtypes._to_complex_dtype(p.dtype))

    # build companion matrix and find its eigenvalues (the roots)
    A = diag(ones((p.size - 2, ), p.dtype), -1)
    A = A.at[0, :].set(-p[1:] / p[0])
    return linalg.eigvals(A)
Пример #11
0
def matrix_rank(M, tol=None):
  """Return the rank of ``M``: the number of singular values above ``tol``.

  ``M`` may have at most two dimensions. For 0-d/1-d input the rank is 1 if
  any entry is nonzero and 0 otherwise (as int32). When ``tol`` is None it
  defaults to ``S.max() * max(M.shape) * eps``, where ``S`` are the singular
  values and ``eps`` is their dtype's machine epsilon.

  Raises:
    TypeError: if ``M`` has more than two dimensions.
  """
  M, = _promote_dtypes_inexact(jnp.asarray(M))
  ndim = M.ndim
  if ndim > 2:
    raise TypeError("array should have 2 or fewer dimensions")
  if ndim < 2:
    return jnp.any(M != 0).astype(jnp.int32)
  singular_values = svd(M, full_matrices=False, compute_uv=False)
  if tol is None:
    eps = jnp.finfo(singular_values.dtype).eps
    tol = singular_values.max() * np.max(M.shape) * eps
  return jnp.sum(singular_values > tol)
Пример #12
0
def polyder(p, m=1):
  """Return the ``m``-th derivative of the polynomial with coefficients ``p``.

  Coefficients are ordered highest degree first. ``m`` must be a concrete
  integer. ``m == 0`` returns ``p`` (after dtype promotion).

  Raises:
    ValueError: if ``m`` is negative.
  """
  _check_arraylike("polyder", p)
  m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder")
  p, = _promote_dtypes_inexact(p)
  if m < 0:
    raise ValueError("Order of derivative must be positive")
  if m == 0:
    return p
  # The surviving coefficients have degrees d = len(p)-1, ..., m. Each is
  # multiplied by the falling factorial d * (d-1) * ... * (d-m+1); row j of
  # the grid below holds d - j.
  degrees = arange(len(p), m, -1) - 1
  falling = degrees[np.newaxis, :] - arange(m)[:, np.newaxis]
  return p[:-m] * falling.prod(0)
Пример #13
0
def polymul(a1, a2, *, trim_leading_zeros=False):
    """Multiply two polynomials given as coefficient arrays (highest degree first).

    With ``trim_leading_zeros=True``, leading zeros are stripped first. This
    step is skipped when both operands have at most one coefficient. An
    empty operand is treated as the zero polynomial ``[0]``. The product is
    the full convolution of the coefficient arrays.
    """
    _check_arraylike("polymul", a1, a2)
    a1, a2 = _promote_dtypes_inexact(a1, a2)
    if trim_leading_zeros and not (len(a1) <= 1 and len(a2) <= 1):
        a1 = trim_zeros(a1, trim='f')
        a2 = trim_zeros(a2, trim='f')
    if len(a1) == 0:
        a1 = asarray([0], dtype=a2.dtype)
    if len(a2) == 0:
        a2 = asarray([0], dtype=a1.dtype)
    return convolve(a1, a2, mode='full')
Пример #14
0
def eigh(a, UPLO=None, symmetrize_input=True):
  """Eigendecomposition of a Hermitian (or real symmetric) matrix.

  ``UPLO`` picks the triangle of ``a`` passed to the solver: None or "L"
  selects the lower triangle, "U" the upper. Returns ``(w, v)``: the
  eigenvalues followed by the eigenvectors. This reorders the
  ``(v, w)`` pair produced by ``lax_linalg.eigh``.

  Raises:
    ValueError: if ``UPLO`` is not None, "L" or "U".
  """
  if UPLO not in (None, "L", "U"):
    raise ValueError(f"UPLO must be one of None, 'L', or 'U', got {UPLO}")
  lower = UPLO != "U"

  a, = _promote_dtypes_inexact(jnp.asarray(a))
  eigvecs, eigvals = lax_linalg.eigh(a, lower=lower,
                                     symmetrize_input=symmetrize_input)
  return eigvals, eigvecs
Пример #15
0
def qr(a, mode="reduced"):
  """QR decomposition of ``a``.

  ``mode`` is one of "reduced" (default), "r", "full" or "complete". Only
  "complete" requests full-size factors. "r" returns just ``r``; every
  other mode returns ``(q, r)``.

  Raises:
    ValueError: for any other ``mode``.
  """
  if mode == "complete":
    full_matrices = True
  elif mode in ("reduced", "r", "full"):
    full_matrices = False
  else:
    raise ValueError(f"Unsupported QR decomposition mode '{mode}'")
  a, = _promote_dtypes_inexact(jnp.asarray(a))
  q, r = lax_linalg.qr(a, full_matrices)
  return r if mode == "r" else (q, r)
Пример #16
0
def det(a):
  """Determinant of a square matrix or a stack of square matrices.

  2x2 and 3x3 inputs use the closed-form helpers ``_det_2x2`` / ``_det_3x3``.
  Other sizes compute ``sign * exp(logdet)`` from ``slogdet``.

  Raises:
    ValueError: if ``a`` does not have shape ``[..., n, n]``.
  """
  a, = _promote_dtypes_inexact(jnp.asarray(a))
  a_shape = jnp.shape(a)
  if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
    # Name the public entry point in the message, as slogdet() does.
    raise ValueError(
        f"Argument to det() must have shape [..., n, n], got {a_shape}")
  n = a_shape[-1]
  if n == 2:
    return _det_2x2(a)
  if n == 3:
    return _det_3x3(a)
  sign, logdet = slogdet(a)
  return sign * jnp.exp(logdet)
Пример #17
0
def slogdet(a, *, method: Optional[str] = None):
  """Sign and natural log of the absolute determinant of ``a``.

  ``a`` must have shape ``[..., n, n]``. ``method`` selects the
  factorization: None or "lu" uses ``_slogdet_lu``, "qr" uses
  ``_slogdet_qr``.

  Raises:
    ValueError: for a non-square input or an unknown ``method``.
  """
  a, = _promote_dtypes_inexact(jnp.asarray(a))
  a_shape = jnp.shape(a)
  is_square_stack = len(a_shape) >= 2 and a_shape[-1] == a_shape[-2]
  if not is_square_stack:
    raise ValueError(
        f"Argument to slogdet() must have shape [..., n, n], got {a_shape}")

  if method in (None, "lu"):
    return _slogdet_lu(a)
  if method == "qr":
    return _slogdet_qr(a)
  raise ValueError(f"Unknown slogdet method '{method}'. Supported methods "
                   "are 'lu' (`None`), and 'qr'.")
Пример #18
0
def _lu(a, permute_l):
    """LU decomposition of the 2-D matrix ``a`` built from ``lax_linalg.lu``.

    Returns ``(p, l, u)``, or ``(p @ l, u)`` when ``permute_l`` is true.
    ``l`` is unit lower-triangular of shape ``(m, k)`` and ``u`` is
    upper-triangular of shape ``(k, n)``, with ``k = min(m, n)``. ``p`` is a
    real-valued permutation matrix. Presumably ``a == p @ l @ u`` (the SciPy
    convention); this depends on the permutation semantics of
    ``lax_linalg.lu`` — TODO confirm.
    """
    a, = _promote_dtypes_inexact(jnp.asarray(a))
    lu, _pivots, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    m, n = jnp.shape(a)
    k = min(m, n)
    # p[i, j] == 1 exactly where permutation[j] == i. jnp.real keeps p
    # real-valued even when `a` is complex.
    is_perm = permutation[None, :] == jnp.arange(m)[:, None]
    p = jnp.real(jnp.array(is_perm, dtype=dtype))
    lower_part = jnp.tril(lu, -1)[:, :k] + jnp.eye(m, k, dtype=dtype)
    upper_part = jnp.triu(lu)[:k, :]
    if not permute_l:
        return p, lower_part, upper_part
    return jnp.matmul(p, lower_part), upper_part
Пример #19
0
def qr(a, mode="reduced"):
    """QR decomposition of ``a``.

    Modes:
      * "raw": returns the (transposed) packed output of
        ``lax_linalg.geqrf`` together with its ``taus``.
      * "reduced" (default), "r", "full": reduced factors.
      * "complete": full-size factors.
    "r" returns only ``r``; the factor modes otherwise return ``(q, r)``.

    Raises:
      ValueError: for any other ``mode``.
    """
    a, = _promote_dtypes_inexact(jnp.asarray(a))
    if mode == "raw":
        packed, taus = lax_linalg.geqrf(a)
        return _T(packed), taus
    if mode == "complete":
        full_matrices = True
    elif mode in ("reduced", "r", "full"):
        full_matrices = False
    else:
        raise ValueError(f"Unsupported QR decomposition mode '{mode}'")
    q, r = lax_linalg.qr(a, full_matrices=full_matrices)
    return r if mode == "r" else (q, r)
Пример #20
0
def vq(obs, code_book, check_finite=True):
    """Assign each observation to its nearest code book entry.

    Args:
      obs: observations, rank 1 (treated as ``(N, 1)``) or rank 2 ``(N, D)``.
      code_book: code vectors with the same rank as ``obs``.
      check_finite: accepted for SciPy API compatibility; not used here.

    Returns:
      ``(code, dist_min)``: for each observation, the index of the closest
      code vector and the distance to it (``norm`` over the last axis).

    Raises:
      ValueError: if the ranks differ or are not 1 or 2.
    """
    _check_arraylike("scipy.cluster.vq.vq", obs, code_book)
    # Promote before reading .ndim: array-likes such as Python scalars may
    # pass _check_arraylike yet have no .ndim attribute.
    obs, code_book = _promote_dtypes_inexact(obs, code_book)
    if obs.ndim != code_book.ndim:
        raise ValueError("Observation and code_book should have the same rank")
    if obs.ndim == 1:
        obs, code_book = obs[..., None], code_book[..., None]
    if obs.ndim != 2:
        raise ValueError("ndim different than 1 or 2 are not supported")

    # explicitly rank promotion
    dist = vmap(lambda ob: norm(ob[None] - code_book, axis=-1))(obs)
    code = argmin(dist, axis=-1)
    dist_min = vmap(operator.getitem)(dist, code)
    return code, dist_min
Пример #21
0
def _qr(a, mode, pivoting):
    """QR decomposition with SciPy-style ``mode`` names.

    "full" and "r" compute full-size factors; "economic" computes reduced
    ones. Returns ``(q, r)``, or the 1-tuple ``(r,)`` when ``mode == "r"``.

    Raises:
      NotImplementedError: if ``pivoting`` is true.
      ValueError: for an unknown ``mode``.
    """
    if pivoting:
        raise NotImplementedError(
            "The pivoting=True case of qr is not implemented.")
    if mode == "economic":
        full_matrices = False
    elif mode in ("full", "r"):
        full_matrices = True
    else:
        raise ValueError(f"Unsupported QR decomposition mode '{mode}'")
    a, = _promote_dtypes_inexact(jnp.asarray(a))
    q, r = lax_linalg.qr(a, full_matrices=full_matrices)
    return (r, ) if mode == "r" else (q, r)
Пример #22
0
def polydiv(u, v, *, trim_leading_zeros=False):
    """Divide polynomial ``u`` by ``v`` (coefficients highest degree first).

    Uses synthetic long division and returns ``(quotient, remainder)``. The
    remainder keeps the length of ``u``. With ``trim_leading_zeros=True``,
    the remainder is passed through ``trim_zeros_tol`` with
    ``tol = sqrt(eps)`` to drop near-zero leading coefficients.
    """
    _check_arraylike("polydiv", u, v)
    u, v = _promote_dtypes_inexact(u, v)
    deg_u = len(u) - 1
    deg_v = len(v) - 1
    inv_lead = 1. / v[0]
    n_steps = deg_u - deg_v + 1
    quotient = zeros(max(n_steps, 1), dtype=u.dtype)  # force same dtype
    for k in range(n_steps):
        d = inv_lead * u[k]
        quotient = quotient.at[k].set(d)
        # Subtract d * v aligned at position k, cancelling u[k].
        u = u.at[k:k + deg_v + 1].add(-d * v)
    if not trim_leading_zeros:
        return quotient, u
    # use the square root of finfo(dtype).eps to approximate the absolute
    # tolerance used in numpy
    return quotient, trim_zeros_tol(u, tol=sqrt(finfo(u.dtype).eps), trim='f')
Пример #23
0
def _cho_solve(c, b, lower):
    """Solve ``A x = b`` given a triangular Cholesky factor ``c`` of ``A``.

    With ``lower=True``, ``c`` is treated as ``L`` with ``A = L L^H``: the
    first pass solves ``L y = b`` and the second ``L^H x = y``. With
    ``lower=False``, ``c`` is treated as ``U`` with ``A = U^H U``: the first
    pass solves ``U^H y = b`` and the second ``U x = y``.
    """
    c, b = _promote_dtypes_inexact(jnp.asarray(c), jnp.asarray(b))
    lax_linalg._check_solve_shapes(c, b)
    # The conjugate transpose of the factor is applied on the first pass for
    # an upper factor and on the second pass for a lower one.
    for use_adjoint in (not lower, lower):
        b = lax_linalg.triangular_solve(c,
                                        b,
                                        left_side=True,
                                        lower=lower,
                                        transpose_a=use_adjoint,
                                        conjugate_a=use_adjoint)
    return b
Пример #24
0
def polyint(p, m=1, k=None):
  """Return the ``m``-th antiderivative of the polynomial ``p``.

  Coefficients are ordered highest degree first. ``k`` gives the integration
  constants. A scalar or length-1 ``k`` is reused for all ``m`` integrations;
  otherwise ``k`` must have length ``m``. ``None`` means every constant is
  zero.

  Raises:
    ValueError: if ``m`` is negative or ``k`` has an incompatible shape.
  """
  m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
  if k is None:
    k = 0
  _check_arraylike("polyint", p, k)
  p, k = _promote_dtypes_inexact(p, k)
  if m < 0:
    raise ValueError("Order of integral must be positive (see polyder)")
  k = atleast_1d(k)
  if len(k) == 1:
    k = full((m,), k[0])
  if k.shape != (m,):
    raise ValueError("k must be a scalar or a rank-1 array of length 1 or m.")
  if m == 0:
    return p
  # The output coefficient of degree D is divided by D * (D-1) * ... * (D-m+1).
  # Non-positive factors are clamped to 1, so each constant from k is only
  # scaled by the integrations that come after it.
  out_degrees = arange(len(p) + m, 0, -1) - 1
  factors = out_degrees[np.newaxis, :] - arange(m)[:, np.newaxis]
  coeff = maximum(1, factors).prod(0)
  return true_divide(concatenate((p, k)), coeff)
Пример #25
0
def roots(p, *, strip_zeros=True):
    """Return the roots of the polynomial with coefficients ``p``.

    Args:
      p: rank-1 array of coefficients, highest degree first.
      strip_zeros: if True (default), leading zeros are removed before the
        companion-matrix solve. This requires a concrete (non-traced) ``p``.
        If False, ``_roots_with_zeros`` handles them. That path can be used
        under JIT, but leading zeros may produce NaN roots.

    Returns:
      Array of roots. It is an empty complex array when fewer than two
      coefficients remain.

    Raises:
      ValueError: if ``p`` is not rank-1.
    """
    _check_arraylike("roots", p)
    p = atleast_1d(*_promote_dtypes_inexact(p))
    if p.ndim != 1:
        raise ValueError("Input must be a rank-1 array.")
    if p.size < 2:
        return array([], dtype=dtypes._to_complex_dtype(p.dtype))
    # Index of the first nonzero coefficient (argmin of the "is zero" mask),
    # or len(p) when every coefficient is zero.
    num_leading_zeros = _where(all(p == 0), len(p), argmin(p == 0))

    if strip_zeros:
        num_leading_zeros = core.concrete_or_error(
            int, num_leading_zeros,
            "The error occurred in the jnp.roots() function. To use this within a "
            "JIT-compiled context, pass strip_zeros=False, but be aware that leading zeros "
            "will result in some returned roots being set to NaN.")
        stripped = p[num_leading_zeros:]
        # An all-zero or effectively constant polynomial has no roots; the
        # companion-matrix construction needs at least two coefficients.
        if stripped.size < 2:
            return array([], dtype=dtypes._to_complex_dtype(p.dtype))
        return _roots_no_zeros(stripped)
    else:
        return _roots_with_zeros(p, num_leading_zeros)
Пример #26
0
def _eigh(a, b, lower, eigvals_only, eigvals, type):
    """SciPy-style Hermitian eigendecomposition of ``a``.

    Only the standard problem is supported: ``b is None``, ``type == 1`` and
    ``eigvals is None``. Returns the eigenvalues when ``eigvals_only`` is
    true, otherwise ``(eigenvalues, eigenvectors)``.

    Raises:
      NotImplementedError: for any unsupported argument.
    """
    unsupported = (
        (b is not None,
         "Only the b=None case of eigh is implemented"),
        (type != 1,
         "Only the type=1 case of eigh is implemented."),
        (eigvals is not None,
         "Only the eigvals=None case of eigh is implemented."),
    )
    for failed, message in unsupported:
        if failed:
            raise NotImplementedError(message)

    a, = _promote_dtypes_inexact(jnp.asarray(a))
    # lax_linalg.eigh yields (eigenvectors, eigenvalues).
    evecs, evals = lax_linalg.eigh(a, lower=lower)
    return evals if eigvals_only else (evals, evecs)
Пример #27
0
def frexp(x):
    """Decompose ``x`` into a mantissa and a power-of-two exponent, elementwise.

    Works on the IEEE bit pattern. For finite nonzero inputs, the mantissa's
    exponent field is overwritten with ``bias - 1``, which places its
    magnitude presumably in ``[0.5, 1)`` as with ``numpy.frexp`` (TODO
    confirm subnormal handling inside ``_normalize_float``). Infinite, NaN
    and zero inputs are returned unchanged with exponent 0.

    Returns:
      ``(mantissa, exponent)``; ``exponent`` is converted to int32.

    Raises:
      TypeError: for complex-valued input.
    """
    _check_arraylike("frexp", x)
    x, = _promote_dtypes_inexact(x)
    if dtypes.issubdtype(x.dtype, np.complexfloating):
        raise TypeError("frexp does not support complex-valued inputs")

    dtype = dtypes.dtype(x)
    info = dtypes.finfo(dtype)
    # All-ones mask covering the `nexp` exponent bits.
    mask = (1 << info.nexp) - 1
    # Exponent bias: 2**(nexp - 1) - 1 (e.g. 127 for float32).
    bias = ((1 << info.nexp) - 1) >> 1

    # NOTE(review): assumes x1 is the integer bit pattern of (a normalized) x
    # and x2 an integer exponent correction, e.g. for subnormals — confirm in
    # _normalize_float.
    x1, x2 = _normalize_float(x)
    # Add the unbiased exponent read from the exponent field; the +1 accounts
    # for the mantissa being scaled into [0.5, 1) rather than [1, 2).
    x2 += ((x1 >> info.nmant) & mask) - bias + 1
    # Clear the exponent field ...
    x1 &= ~(mask << info.nmant)
    # ... and set it to bias - 1, i.e. a factor of 2**-1.
    x1 |= (bias - 1) << info.nmant
    # Reinterpret the modified bits as a float of the original dtype.
    x1 = lax.bitcast_convert_type(x1, dtype)

    # inf, NaN and 0 pass through unchanged with exponent 0.
    cond = isinf(x) | isnan(x) | (x == 0)
    x2 = _where(cond, lax_internal._zeros(x2), x2)
    return _where(cond, x, x1), lax.convert_element_type(x2, np.int32)
Пример #28
0
def fft(x, fft_type: Union[xla_client.FftType, str],
        fft_lengths: Sequence[int]):
    """Apply an XLA FFT of kind ``fft_type`` to ``x``.

    ``fft_type`` is an ``xla_client.FftType`` or a string understood by
    ``_str_to_fft_type``. RFFT requires real input, which is promoted to an
    inexact dtype; every other kind promotes ``x`` to complex. An empty
    ``fft_lengths`` returns the promoted ``x`` unchanged. Otherwise the
    lengths are passed to ``fft_p`` as a tuple.

    Raises:
      TypeError: for an unrecognized ``fft_type``.
      ValueError: for complex input to RFFT.
    """
    if isinstance(fft_type, str):
        typ = _str_to_fft_type(fft_type)
    elif isinstance(fft_type, xla_client.FftType):
        typ = fft_type
    else:
        raise TypeError(f"Unknown FFT type value '{fft_type}'")

    is_real_fft = typ == xla_client.FftType.RFFT
    if is_real_fft and np.iscomplexobj(x):
        raise ValueError("only real valued inputs supported for rfft")
    promote = _promote_dtypes_inexact if is_real_fft else _promote_dtypes_complex
    x, = promote(x)
    if len(fft_lengths) == 0:
        # XLA FFT doesn't support 0-rank.
        return x
    return fft_p.bind(x, fft_type=typ, fft_lengths=tuple(fft_lengths))
Пример #29
0
def _solve(a, b, sym_pos, lower):
    """Solve ``a @ x = b``, using a Cholesky-based solver when ``sym_pos``.

    Without ``sym_pos`` this defers to ``np_linalg.solve``. With ``sym_pos``,
    ``a`` is Cholesky-factored once, with gradients stopped through the
    factorization. ``lax.custom_linear_solve`` then reuses those factors for
    the solve and for its sensitivities.
    """
    if not sym_pos:
        return np_linalg.solve(a, b)

    a, b = _promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
    lax_linalg._check_solve_shapes(a, b)

    # With custom_linear_solve, we can reuse the same factorization when
    # computing sensitivities. This is considerably faster.
    factors = cho_factor(lax.stop_gradient(a), lower=lower)

    def matvec(vec):
        return lax_linalg._matvec_multiply(a, vec)

    def solve_with_factors(_, rhs):
        return cho_solve(factors, rhs)

    custom_solve = partial(lax.custom_linear_solve,
                           matvec,
                           solve=solve_with_factors,
                           symmetric=True)
    if a.ndim == b.ndim + 1:
        # b.shape == [..., m]
        return custom_solve(b)
    # b.shape == [..., m, k]: map the vector solve over the column axis.
    return vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
Пример #30
0
def initial_step_size(fun, t0, y0, order, rtol, atol, f0):
  """Heuristic first step size for an adaptive ODE integrator.

  Algorithm from:
  E. Hairer, S. P. Norsett G. Wanner,
  Solving Ordinary Differential Equations I: Nonstiff Problems, Sec. II.4.

  Args:
    fun: derivative function, called here as ``fun(y, t)``.
    t0: initial time.
    y0: initial state.
    order: order of the integration method.
    rtol: relative tolerance.
    atol: absolute tolerance.
    f0: derivative at the initial point (presumably ``fun(y0, t0)``).

  Returns:
    The proposed initial step, capped at ``100 * h0``.
  """
  y0, f0 = _promote_dtypes_inexact(y0, f0)
  dtype = y0.dtype
  scale = (atol + jnp.abs(y0) * rtol).astype(dtype)

  def scaled_norm(v):
    return jnp.linalg.norm(v / scale)

  d0 = scaled_norm(y0)
  d1 = scaled_norm(f0)

  # First guess from the sizes of y0 and f0, then one explicit Euler step.
  tiny = (d0 < 1e-5) | (d1 < 1e-5)
  h0 = jnp.where(tiny, 1e-6, 0.01 * d0 / d1)
  y1 = y0 + h0.astype(dtype) * f0
  f1 = fun(y1, t0 + h0)
  # Finite-difference estimate of the size of the second derivative.
  d2 = scaled_norm(f1 - f0) / h0

  both_small = (d1 <= 1e-15) & (d2 <= 1e-15)
  # NOTE(review): the cited algorithm uses max(d1, d2); this uses d1 + d2,
  # which gives a slightly more conservative step, and jnp.max of that sum
  # appears to be a no-op. Kept as-is to preserve current results.
  h1 = jnp.where(both_small,
                 jnp.maximum(1e-6, h0 * 1e-3),
                 (0.01 / jnp.max(d1 + d2)) ** (1. / (order + 1.)))

  return jnp.minimum(100. * h0, h1)