Code example #1
def zeta(x, q=None):
    assert q is not None, "Riemann zeta function is not implemented yet."
    # Reference: Johansson, Fredrik.
    # "Rigorous high-precision computation of the Hurwitz zeta function and its derivatives."
    # Numerical Algorithms 69.2 (2015): 253-270.
    # https://arxiv.org/abs/1309.2877 - formula (5)
    # here we keep the same notation as in reference
    s, a = _promote_args_inexact("zeta", x, q)
    dtype = lax.dtype(a).type
    s_, a_ = jnp.expand_dims(s, -1), jnp.expand_dims(a, -1)
    # precision ~ N, M
    N = M = dtype(8) if lax.dtype(a) == jnp.float32 else dtype(16)
    assert M <= len(_BERNOULLI_COEFS)
    k = jnp.expand_dims(np.arange(N, dtype=N.dtype), tuple(range(a.ndim)))
    S = jnp.sum((a_ + k)**-s_, -1)
    I = lax.div((a + N)**(dtype(1) - s), s - dtype(1))
    T0 = (a + N)**-s
    m = jnp.expand_dims(np.arange(2 * M, dtype=M.dtype), tuple(range(s.ndim)))
    s_over_a = (s_ + m) / (a_ + N)
    T1 = jnp.cumprod(s_over_a, -1)[..., ::2]
    T1 = jnp.clip(T1, a_max=jnp.finfo(dtype).max)
    coefs = np.expand_dims(
        np.array(_BERNOULLI_COEFS[:T1.shape[-1]], dtype=dtype),
        tuple(range(a.ndim)))
    T1 = T1 / coefs
    T = T0 * (dtype(0.5) + T1.sum(-1))
    return S + I + T
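The three returned terms follow the Euler–Maclaurin split of the cited reference: S is the truncated power sum, I the tail integral (a+N)^(1-s)/(s-1), and T the Bernoulli correction (a+N)^(-s) * (1/2 + sum_j B_{2j}/(2j)! * (s)_{2j-1}/(a+N)^(2j-1)), with _BERNOULLI_COEFS holding the (2j)!/B_{2j} factors. Below is a minimal usage sketch, assuming this is the implementation behind the public `jax.scipy.special.zeta`; the value zeta(2, 1) = pi^2/6 is used only as an independent check.

import jax.numpy as jnp
from jax.scipy.special import zeta

# Hurwitz zeta with q=1 gives the Riemann zeta values.
print(zeta(2.0, 1.0))   # ~1.6449341
print(jnp.pi**2 / 6)    # reference value for comparison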
Code example #2
def multigammaln(a, d):
  d = core.concrete_or_error(int, d, "d argument of multigammaln")
  a, d_ = _promote_args_inexact("multigammaln", a, d)

  constant = lax.mul(lax.mul(lax.mul(_constant_like(a, 0.25), d_),
                             lax.sub(d_, _constant_like(a, 1))),
                     lax.log(_constant_like(a, np.pi)))
  b = lax.div(jnp.arange(d, dtype=d_.dtype), _constant_like(a, 2))
  res = jnp.sum(gammaln(jnp.expand_dims(a, axis=-1) -
                        jnp.expand_dims(b, axis=tuple(range(a.ndim)))),
                axis=-1)
  return res + constant
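A short usage sketch, assuming the function above is the one exposed as `jax.scipy.special.multigammaln`; the SciPy call is included only as an independent cross-check.

from jax.scipy.special import multigammaln
from scipy.special import multigammaln as scipy_multigammaln

a, d = 5.0, 3
print(multigammaln(a, d))        # JAX result
print(scipy_multigammaln(a, d))  # SciPy reference, should agree closely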
Code example #3
File: index_tricks.py  Project: xueeinstein/jax
  def __getitem__(self, key):
    if not isinstance(key, tuple):
      key = (key,)

    params = [self.axis, self.ndmin, self.trans1d, -1]

    if isinstance(key[0], str):
      # split off the directive
      directive, *key = key  # pytype: disable=bad-unpacking
      # check two special cases: matrix directives
      if directive == "r":
        params[-1] = 0
      elif directive == "c":
        params[-1] = 1
      else:
        vec = directive.split(",")
        k = len(vec)
        if k < 4:
          vec += params[k:]
        else:
          # ignore everything after the first three comma-separated ints
          vec = vec[:3] + [params[-1]]
        try:
          params = list(map(int, vec))
        except ValueError as err:
          raise ValueError(
            f"could not understand directive {directive!r}"
          ) from err

    axis, ndmin, trans1d, matrix = params

    output = []
    for item in key:
      if isinstance(item, slice):
        newobj = _make_1d_grid_from_slice(item, op_name=self.op_name)
      elif isinstance(item, str):
        raise ValueError("string directive must be placed at the beginning")
      else:
        newobj = item

      newobj = array(newobj, copy=False, ndmin=ndmin)

      if trans1d != -1 and ndmin - np.ndim(item) > 0:
        shape_obj = list(range(ndmin))
        # Calculate number of left shifts, with overflow protection by mod
        num_lshifts = ndmin - abs(ndmin + trans1d + 1) % ndmin
        shape_obj = tuple(shape_obj[num_lshifts:] + shape_obj[:num_lshifts])

        newobj = transpose(newobj, shape_obj)

      output.append(newobj)

    res = concatenate(tuple(output), axis=axis)

    if matrix != -1 and res.ndim == 1:
      # insert 2nd dim at axis 0 or 1
      res = expand_dims(res, matrix)

    return res
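A usage sketch of the concatenator this method backs, assuming it is reachable as `jnp.r_`/`jnp.c_`; the string directive follows the "axis,ndmin,trans1d" convention parsed above.

import jax.numpy as jnp

print(jnp.r_[1:4, 0, 4])                    # 1-D concatenation: [1 2 3 0 4]
print(jnp.r_['0,2', [1, 2, 3], [4, 5, 6]])  # axis=0, ndmin=2 -> shape (2, 3)
print(jnp.c_[[1, 2, 3], [4, 5, 6]])         # column-wise stacking -> shape (3, 2)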
Code example #4
File: linalg.py  Project: yuejiesong1900/jax
def triangular_solve(a,
                     b,
                     left_side: bool = False,
                     lower: bool = False,
                     transpose_a: bool = False,
                     conjugate_a: bool = False,
                     unit_diagonal: bool = False):
    r"""Triangular solve.

  Solves either the matrix equation

  .. math::
    \mathit{op}(A) . X = B

  if ``left_side`` is ``True`` or

  .. math::
    X . \mathit{op}(A) = B

  if ``left_side`` is ``False``.

  ``A`` must be a lower or upper triangular square matrix, where
  :math:`\mathit{op}(A)` transposes :math:`A` if ``transpose_a`` is ``True``
  and/or takes its complex conjugate if ``conjugate_a`` is ``True``.

  Args:
    a: A batch of matrices with shape ``[..., m, m]``.
    b: A batch of matrices with shape ``[..., m, n]`` if ``left_side`` is
      ``True`` or shape ``[..., n, m]`` otherwise.
    left_side: describes which of the two matrix equations to solve; see above.
    lower: describes which triangle of ``a`` should be used. The other triangle
      is ignored.
    transpose_a: if ``True``, the value of ``a`` is transposed.
    conjugate_a: if ``True``, the complex conjugate of ``a`` is used in the
      solve. Has no effect if ``a`` is real.
    unit_diagonal: if ``True``, the diagonal of ``a`` is assumed to be unit
      (all 1s) and not accessed.

  Returns:
    A batch of matrices the same shape and dtype as ``b``.
  """
    conjugate_a = conjugate_a and jnp.issubdtype(lax.dtype(a),
                                                 jnp.complexfloating)
    singleton = jnp.ndim(b) == jnp.ndim(a) - 1
    if singleton:
        b = jnp.expand_dims(b, -1 if left_side else -2)
    out = triangular_solve_p.bind(a,
                                  b,
                                  left_side=left_side,
                                  lower=lower,
                                  transpose_a=transpose_a,
                                  conjugate_a=conjugate_a,
                                  unit_diagonal=unit_diagonal)
    if singleton:
        out = out[..., 0] if left_side else out[..., 0, :]
    return out
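A small usage sketch, assuming this is the function exported as `jax.lax.linalg.triangular_solve`; it solves A @ X = B for a lower-triangular A and checks the residual.

import jax.numpy as jnp
from jax.lax.linalg import triangular_solve

a = jnp.array([[2.0, 0.0], [1.0, 3.0]])  # lower-triangular A
b = jnp.array([[4.0], [7.0]])            # right-hand side, shape (2, 1)
x = triangular_solve(a, b, left_side=True, lower=True)
print(x)          # solution of A @ X = B
print(a @ x - b)  # residual, ~0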
Code example #5
File: special.py  Project: xeransis/jax
def multigammaln(a, d):
    a, = _promote_args_inexact("multigammaln", a)
    d = lax.convert_element_type(d, lax.dtype(a))
    constant = lax.mul(
        lax.mul(lax.mul(_constant_like(a, 0.25), d),
                lax.sub(d, _constant_like(a, 1))),
        lax.log(_constant_like(a, np.pi)))
    res = jnp.sum(gammaln(
        jnp.expand_dims(a, axis=-1) -
        lax.div(jnp.arange(d), _constant_like(a, 2))),
                  axis=-1)
    return res + constant
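Both multigammaln versions compute the definition log Gamma_d(a) = d(d-1)/4 * log(pi) + sum_{j=0}^{d-1} gammaln(a - j/2), which can be checked directly; a sketch, assuming `jax.scipy.special` exposes both `multigammaln` and `gammaln`.

import jax.numpy as jnp
from jax.scipy.special import multigammaln, gammaln

a, d = 4.5, 3
explicit = d * (d - 1) / 4 * jnp.log(jnp.pi) + sum(
    gammaln(a - j / 2) for j in range(d))
print(multigammaln(a, d))  # library value
print(explicit)            # explicit sum over j = 0..d-1, should match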
Code example #6
File: linalg.py  Project: stevenchang8/jax
def triangular_solve(a, b, left_side=False, lower=False, transpose_a=False,
                     conjugate_a=False, unit_diagonal=False):
  conjugate_a = conjugate_a and jnp.issubdtype(lax.dtype(a), jnp.complexfloating)
  singleton = jnp.ndim(b) == jnp.ndim(a) - 1
  if singleton:
    b = jnp.expand_dims(b, -1 if left_side else -2)
  out = triangular_solve_p.bind(
      a, b, left_side=left_side, lower=lower, transpose_a=transpose_a,
      conjugate_a=conjugate_a, unit_diagonal=unit_diagonal)
  if singleton:
    out = out[..., 0] if left_side else out[..., 0, :]
  return out
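This older copy has the same `singleton` branch as example #4: a rank-1 `b` is temporarily expanded to a column (or row) and the extra axis is stripped again on the way out. A sketch of that path, assuming the same `jax.lax.linalg.triangular_solve` entry point.

import jax.numpy as jnp
from jax.lax.linalg import triangular_solve

a = jnp.array([[2.0, 0.0], [1.0, 3.0]])
b = jnp.array([4.0, 7.0])   # rank-1 b triggers the singleton path
x = triangular_solve(a, b, left_side=True, lower=True)
print(x.shape)              # (2,): the temporary axis is removed
print(a @ x)                # ~[4. 7.]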
Code example #7
File: linalg.py  Project: jbampton/jax
def norm(x,
         ord=None,
         axis: Union[None, Tuple[int, ...], int] = None,
         keepdims=False):
    x = _promote_arg_dtypes(jnp.asarray(x))
    x_shape = jnp.shape(x)
    ndim = len(x_shape)

    if axis is None:
        # NumPy has an undocumented behavior that admits arbitrary rank inputs if
        # `ord` is None: https://github.com/numpy/numpy/issues/14215
        if ord is None:
            return jnp.sqrt(
                jnp.sum(jnp.real(x * jnp.conj(x)), keepdims=keepdims))
        axis = tuple(range(ndim))
    elif isinstance(axis, tuple):
        axis = tuple(canonicalize_axis(ax, ndim) for ax in axis)
    else:
        axis = (canonicalize_axis(axis, ndim), )

    num_axes = len(axis)
    if num_axes == 1:
        if ord is None or ord == 2:
            return jnp.sqrt(
                jnp.sum(jnp.real(x * jnp.conj(x)),
                        axis=axis,
                        keepdims=keepdims))
        elif ord == jnp.inf:
            return jnp.amax(jnp.abs(x), axis=axis, keepdims=keepdims)
        elif ord == -jnp.inf:
            return jnp.amin(jnp.abs(x), axis=axis, keepdims=keepdims)
        elif ord == 0:
            return jnp.sum(x != 0,
                           dtype=jnp.finfo(lax.dtype(x)).dtype,
                           axis=axis,
                           keepdims=keepdims)
        elif ord == 1:
            # Numpy has a special case for ord == 1 as an optimization. We don't
            # really need the optimization (XLA could do it for us), but the Numpy
            # code has slightly different type promotion semantics, so we need a
            # special case too.
            return jnp.sum(jnp.abs(x), axis=axis, keepdims=keepdims)
        else:
            abs_x = jnp.abs(x)
            ord = lax_internal._const(abs_x, ord)
            out = jnp.sum(abs_x**ord, axis=axis, keepdims=keepdims)
            return jnp.power(out, 1. / ord)

    elif num_axes == 2:
        row_axis, col_axis = cast(Tuple[int, ...], axis)
        if ord is None or ord in ('f', 'fro'):
            return jnp.sqrt(
                jnp.sum(jnp.real(x * jnp.conj(x)),
                        axis=axis,
                        keepdims=keepdims))
        elif ord == 1:
            if not keepdims and col_axis > row_axis:
                col_axis -= 1
            return jnp.amax(jnp.sum(jnp.abs(x),
                                    axis=row_axis,
                                    keepdims=keepdims),
                            axis=col_axis,
                            keepdims=keepdims)
        elif ord == -1:
            if not keepdims and col_axis > row_axis:
                col_axis -= 1
            return jnp.amin(jnp.sum(jnp.abs(x),
                                    axis=row_axis,
                                    keepdims=keepdims),
                            axis=col_axis,
                            keepdims=keepdims)
        elif ord == jnp.inf:
            if not keepdims and row_axis > col_axis:
                row_axis -= 1
            return jnp.amax(jnp.sum(jnp.abs(x),
                                    axis=col_axis,
                                    keepdims=keepdims),
                            axis=row_axis,
                            keepdims=keepdims)
        elif ord == -jnp.inf:
            if not keepdims and row_axis > col_axis:
                row_axis -= 1
            return jnp.amin(jnp.sum(jnp.abs(x),
                                    axis=col_axis,
                                    keepdims=keepdims),
                            axis=row_axis,
                            keepdims=keepdims)
        elif ord in ('nuc', 2, -2):
            x = jnp.moveaxis(x, axis, (-2, -1))
            if ord == 2:
                reducer = jnp.amax
            elif ord == -2:
                reducer = jnp.amin
            else:
                reducer = jnp.sum
            y = reducer(svd(x, compute_uv=False), axis=-1)
            if keepdims:
                y = jnp.expand_dims(y, axis)
            return y
        else:
            raise ValueError("Invalid order '{}' for matrix norm.".format(ord))
    else:
        raise ValueError(
            "Invalid axis values ({}) for jnp.linalg.norm.".format(axis))