コード例 #1
0
def convolve(in1, in2, mode='full', method='auto', precision=None):
    """Convolve two real-valued arrays by delegating to ``_convolve_nd``.

    Args:
      in1: first input array; must expose a ``dtype`` attribute.
      in2: second input array; must expose a ``dtype`` attribute.
      mode: forwarded unchanged to ``_convolve_nd``.
      method: accepted for signature compatibility only. Any value other than
        ``'auto'`` produces a warning and is otherwise ignored.
      precision: forwarded to ``_convolve_nd`` as a keyword argument.

    Returns:
      Whatever ``_convolve_nd`` returns for the given inputs.

    Raises:
      NotImplementedError: if either input has a complex floating dtype.
    """
    if method != 'auto':
        warnings.warn("convolve() ignores method argument")
    # Complex inputs are rejected up front, before any computation happens.
    has_complex_input = any(
        jnp.issubdtype(arr.dtype, jnp.complexfloating) for arr in (in1, in2))
    if has_complex_input:
        raise NotImplementedError("convolve() does not support complex inputs")
    return _convolve_nd(in1, in2, mode, precision=precision)
コード例 #2
0
def correlate(in1, in2, mode='full', method='auto', precision=None):
    """Cross-correlate two real-valued one-dimensional arrays.

    The correlation is computed by reversing ``in2`` and delegating to
    ``_convolve_nd``.

    Args:
      in1: first input array; must be 1-dimensional and non-complex.
      in2: second input array; must be 1-dimensional and non-complex.
      mode: forwarded unchanged to ``_convolve_nd``.
      method: accepted for signature compatibility only. Any value other than
        ``'auto'`` produces a warning and is otherwise ignored.
      precision: forwarded to ``_convolve_nd`` as a keyword argument.

    Returns:
      Whatever ``_convolve_nd`` returns for ``in1`` and the reversed ``in2``.

    Raises:
      NotImplementedError: if either input has a complex floating dtype.
      ValueError: if either input is not 1-dimensional.
    """
    if method != 'auto':
        warnings.warn("correlate() ignores method argument")
    if jnp.issubdtype(in1.dtype, jnp.complexfloating) or jnp.issubdtype(
            in2.dtype, jnp.complexfloating):
        raise NotImplementedError(
            "correlate() does not support complex inputs")
    if jnp.ndim(in1) != 1 or jnp.ndim(in2) != 1:
        # Spell out the supported rank; the message used to carry a literal,
        # never-substituted "{ndim}" placeholder.
        raise ValueError("correlate() only supports 1-dimensional inputs.")
    # Correlation of real signals equals convolution with the reversed kernel.
    return _convolve_nd(in1, in2[::-1], mode, precision=precision)
コード例 #3
0
ファイル: linalg.py プロジェクト: yuejiesong1900/jax
    def body(k, state):
        """Perform step ``k`` of a row-pivoted elimination on the carried state.

        Args:
          k: index of the column (and diagonal entry) processed in this step.
          state: tuple ``(pivot, perm, a)`` carried between steps.

        Returns:
          The updated ``(pivot, perm, a)`` tuple.

        NOTE(review): ``m`` and ``n`` come from the enclosing scope, which is
        not visible here; presumably they are the row and column counts of
        ``a`` -- confirm against the enclosing function.
        """
        pivot, perm, a = state
        m_idx = jnp.arange(m)
        n_idx = jnp.arange(n)

        # Pivot-selection score for column k: |re| + |im| for complex data,
        # plain absolute value otherwise.
        if jnp.issubdtype(a.dtype, jnp.complexfloating):
            t = a[:, k]
            magnitude = jnp.abs(jnp.real(t)) + jnp.abs(jnp.imag(t))
        else:
            magnitude = jnp.abs(a[:, k])
        # Rows above k are masked with -inf so that only rows >= k can be
        # chosen as the pivot row.
        i = jnp.argmax(jnp.where(m_idx >= k, magnitude, -jnp.inf))
        # Record the chosen pivot row for step k.
        pivot = ops.index_update(pivot, ops.index[k], i)

        # Swap rows k and i of the matrix.
        a = ops.index_update(a, ops.index[[k, i], ], a[[i, k], ])

        # Apply the same swap to the permutation vector.
        perm = ops.index_update(perm, ops.index[[i, k], ], perm[[k, i], ])

        # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes
        # (the whole column is rewritten; entries at or above k keep their
        # value via the mask).
        x = a[k, k]
        a = ops.index_update(a, ops.index[:, k],
                             jnp.where(m_idx > k, a[:, k] / x, a[:, k]))

        # a[k+1:, k+1:] -= jnp.outer(a[k+1:, k], a[k, k+1:])
        # Expressed as a full-size outer product masked to the trailing block,
        # with zeros elsewhere, so that array shapes do not depend on k.
        a = a - jnp.where(
            (m_idx[:, None] > k) & (n_idx > k), jnp.outer(a[:, k], a[k, :]),
            jnp.array(0, dtype=a.dtype))
        return pivot, perm, a
コード例 #4
0
def _get_identity(op, dtype):
    """Get an appropriate identity for a given operation in a given dtype."""
    if op is lax.scatter_add:
        return 0
    elif op is lax.scatter_mul:
        return 1
    elif op is lax.scatter_min:
        if jnp.issubdtype(dtype, jnp.integer):
            return jnp.iinfo(dtype).max
        return float('inf')
    elif op is lax.scatter_max:
        if jnp.issubdtype(dtype, jnp.integer):
            return jnp.iinfo(dtype).min
        return -float('inf')
    else:
        raise ValueError(f"Unrecognized op: {op}")
コード例 #5
0
def convolve2d(in1,
               in2,
               mode='full',
               boundary='fill',
               fillvalue=0,
               precision=None):
    """Convolve two real-valued 2-D arrays by delegating to ``_convolve_nd``.

    Args:
      in1: first input array; must be 2-dimensional and non-complex.
      in2: second input array; must be 2-dimensional and non-complex.
      mode: forwarded unchanged to ``_convolve_nd``.
      boundary: only ``'fill'`` is accepted.
      fillvalue: only ``0`` is accepted.
      precision: forwarded to ``_convolve_nd`` as a keyword argument.

    Returns:
      Whatever ``_convolve_nd`` returns for the given inputs.

    Raises:
      NotImplementedError: for any other ``boundary``/``fillvalue``, or if
        either input has a complex floating dtype.
      ValueError: if either input is not 2-dimensional.
    """
    operands = (in1, in2)

    if boundary != 'fill' or fillvalue != 0:
        raise NotImplementedError(
            "convolve2d() only supports boundary='fill', fillvalue=0")

    if any(jnp.issubdtype(arr.dtype, jnp.complexfloating) for arr in operands):
        raise NotImplementedError(
            "convolve2d() does not support complex inputs")

    if any(jnp.ndim(arr) != 2 for arr in operands):
        raise ValueError("convolve2d() only supports 2-dimensional inputs.")

    return _convolve_nd(in1, in2, mode, precision=precision)
コード例 #6
0
ファイル: linalg.py プロジェクト: yuejiesong1900/jax
def _nan_like(c, operand):
    """Build a NaN-filled constant broadcast to the dimensions of ``operand``.

    Args:
      c: builder object used to query the operand shape and create constants.
      operand: the value whose shape and element type are mirrored.

    Returns:
      The result of ``xops.Broadcast`` applied to a scalar NaN constant of the
      operand's element type and the operand's dimensions.
    """
    operand_shape = c.get_shape(operand)
    elem_dtype = operand_shape.element_type()
    # For complex element types the NaN is placed in both the real and the
    # imaginary component.
    is_complex = jnp.issubdtype(elem_dtype, np.complexfloating)
    nan_value = np.nan * (1. + 1j) if is_complex else np.nan
    nan_const = xb.constant(c, np.array(nan_value, dtype=elem_dtype))
    return xops.Broadcast(nan_const, operand_shape.dimensions())
コード例 #7
0
ファイル: linalg.py プロジェクト: yuejiesong1900/jax
def triangular_solve(a,
                     b,
                     left_side: bool = False,
                     lower: bool = False,
                     transpose_a: bool = False,
                     conjugate_a: bool = False,
                     unit_diagonal: bool = False):
    r"""Solve a linear system with a triangular coefficient matrix.

    Depending on ``left_side``, one of two equations is solved for ``X``:

    .. math::
      \mathit{op}(A) . X = B

    when ``left_side`` is ``True``, or

    .. math::
      X . \mathit{op}(A) = B

    when ``left_side`` is ``False``.

    ``A`` must be a square lower or upper triangular matrix.
    :math:`\mathit{op}(A)` transposes :math:`A` when ``transpose_a`` is
    ``True`` and/or takes its complex conjugate when ``conjugate_a`` is
    ``True``.

    If ``b`` has exactly one dimension fewer than ``a``, a temporary axis is
    inserted into ``b`` before the solve and removed from the result.

    Args:
      a: A batch of matrices with shape ``[..., m, m]``.
      b: A batch of matrices with shape ``[..., m, n]`` if ``left_side`` is
        ``True`` or shape ``[..., n, m]`` otherwise.
      left_side: selects which of the two equations above is solved.
      lower: selects the triangle of ``a`` that is used; the other triangle
        is ignored.
      transpose_a: if ``True``, the value of ``a`` is transposed.
      conjugate_a: if ``True``, the complex conjugate of ``a`` is used in the
        solve. Has no effect if ``a`` is real.
      unit_diagonal: if ``True``, the diagonal of ``a`` is assumed to be unit
        (all 1s) and not accessed.

    Returns:
      A batch of matrices the same shape and dtype as ``b``.
    """
    # Conjugation only matters for complex data, so it is requested from the
    # primitive only when ``a`` is complex.
    if conjugate_a:
        conjugate_a = jnp.issubdtype(lax.dtype(a), jnp.complexfloating)

    b_is_vector = jnp.ndim(b) == jnp.ndim(a) - 1
    if b_is_vector:
        b = jnp.expand_dims(b, -1 if left_side else -2)

    solution = triangular_solve_p.bind(a,
                                       b,
                                       left_side=left_side,
                                       lower=lower,
                                       transpose_a=transpose_a,
                                       conjugate_a=conjugate_a,
                                       unit_diagonal=unit_diagonal)

    if not b_is_vector:
        return solution
    # Drop the axis that was inserted above.
    return solution[..., 0] if left_side else solution[..., 0, :]
コード例 #8
0
ファイル: linalg.py プロジェクト: ahoenselaar/jax
def _promote_arg_dtypes(*args):
    """Convert ``args`` to one shared inexact dtype.

    The shared dtype is the lattice result type of all arguments. If that
    type is not inexact, ``jnp.float_`` is used instead and the weak-type
    flag is cleared.

    Returns:
      The single converted value when exactly one argument is given,
      otherwise a list of the converted values.
    """
    result_dtype, is_weak = dtypes._lattice_result_type(*args)
    if not jnp.issubdtype(result_dtype, jnp.inexact):
        result_dtype = jnp.float_
        is_weak = False
    result_dtype = dtypes.canonicalize_dtype(result_dtype)

    converted = [
        lax._convert_element_type(value, result_dtype, is_weak)
        for value in args
    ]
    return converted[0] if len(converted) == 1 else converted
コード例 #9
0
ファイル: linalg.py プロジェクト: ahoenselaar/jax
def _slogdet_jvp(primals, tangents):
    """Forward-mode derivative rule for ``slogdet``.

    Args:
      primals: one-element sequence holding the input matrix ``x``.
      tangents: one-element sequence holding the tangent ``g`` of ``x``.

    Returns:
      ``((sign, logabsdet), (sign_dot, logabsdet_dot))``.
    """
    (x,) = primals
    (g,) = tangents
    sign, ans = slogdet(x)
    ans_dot = jnp.trace(solve(x, g), axis1=-1, axis2=-2)

    if not jnp.issubdtype(jnp._dtype(x), jnp.complexfloating):
        # Real inputs: the sign output gets a zero tangent.
        return (sign, ans), (jnp.zeros_like(sign), ans_dot)

    # Complex inputs: the real part of the trace is the tangent of the
    # log-magnitude; the remainder, scaled by the sign, is the sign tangent.
    real_part = np.real(ans_dot)
    sign_dot = (ans_dot - real_part) * sign
    return (sign, ans), (sign_dot, real_part)
コード例 #10
0
ファイル: ndimage.py プロジェクト: gnecula/jax
def _map_coordinates(input, coordinates, order, mode, cval):
    input = jnp.asarray(input)
    coordinates = [jnp.asarray(c) for c in coordinates]
    cval = jnp.asarray(cval, input.dtype)

    if len(coordinates) != input.ndim:
        raise ValueError(
            'coordinates must be a sequence of length input.ndim, but '
            '{} != {}'.format(len(coordinates), input.ndim))

    index_fixer = _INDEX_FIXERS.get(mode)
    if index_fixer is None:
        raise NotImplementedError(
            'jax.scipy.ndimage.map_coordinates does not yet support mode {}. '
            'Currently supported modes are {}.'.format(mode,
                                                       set(_INDEX_FIXERS)))

    if mode == 'constant':
        is_valid = lambda index, size: (0 <= index) & (index < size)
    else:
        is_valid = lambda index, size: True

    if order == 0:
        interp_fun = _nearest_indices_and_weights
    elif order == 1:
        interp_fun = _linear_indices_and_weights
    else:
        raise NotImplementedError(
            'jax.scipy.ndimage.map_coordinates currently requires order<=1')

    valid_1d_interpolations = []
    for coordinate, size in zip(coordinates, input.shape):
        interp_nodes = interp_fun(coordinate)
        valid_interp = []
        for index, weight in interp_nodes:
            fixed_index = index_fixer(index, size)
            valid = is_valid(index, size)
            valid_interp.append((fixed_index, valid, weight))
        valid_1d_interpolations.append(valid_interp)

    outputs = []
    for items in itertools.product(*valid_1d_interpolations):
        indices, validities, weights = zip(*items)
        if all(valid is True for valid in validities):
            # fast path
            contribution = input[indices]
        else:
            all_valid = functools.reduce(operator.and_, validities)
            contribution = jnp.where(all_valid, input[indices], cval)
        outputs.append(_nonempty_prod(weights) * contribution)
    result = _nonempty_sum(outputs)
    if jnp.issubdtype(input.dtype, jnp.integer):
        result = _round_half_away_from_zero(result)
    return result.astype(input.dtype)
コード例 #11
0
ファイル: linalg.py プロジェクト: stevenchang8/jax
def triangular_solve(a, b, left_side=False, lower=False, transpose_a=False,
                     conjugate_a=False, unit_diagonal=False):
  """Solve a triangular system by binding ``triangular_solve_p``.

  If ``b`` has exactly one dimension fewer than ``a``, a temporary axis is
  inserted into ``b`` before the solve and removed from the result.
  """
  # Conjugation is only requested from the primitive for complex ``a``.
  if conjugate_a:
    conjugate_a = jnp.issubdtype(lax.dtype(a), jnp.complexfloating)

  b_is_vector = jnp.ndim(b) == jnp.ndim(a) - 1
  if b_is_vector:
    extra_axis = -1 if left_side else -2
    b = jnp.expand_dims(b, extra_axis)

  solution = triangular_solve_p.bind(
      a, b, left_side=left_side, lower=lower, transpose_a=transpose_a,
      conjugate_a=conjugate_a, unit_diagonal=unit_diagonal)

  if not b_is_vector:
    return solution
  # Drop the axis that was inserted above.
  return solution[..., 0] if left_side else solution[..., 0, :]
コード例 #12
0
ファイル: linalg.py プロジェクト: xueeinstein/jax
def _slogdet_qr(a):
  # Implementation of slogdet using QR decomposition. One reason we might prefer
  # QR decomposition is that it is more amenable to a fast batched
  # implementation on TPU because of the lack of row pivoting.
  if jnp.issubdtype(lax.dtype(a), jnp.complexfloating):
    raise NotImplementedError("slogdet method='qr' not implemented for complex "
                              "inputs")
  n = a.shape[-1]
  a, taus = lax_linalg.geqrf(a)
  # The determinant of a triangular matrix is the product of its diagonal
  # elements. We are working in log space, so we compute the magnitude as the
  # the trace of the log-absolute values, and we compute the sign separately.
  log_abs_det = jnp.trace(jnp.log(jnp.abs(a)), axis1=-2, axis2=-1)
  sign_diag = jnp.prod(jnp.sign(jnp.diagonal(a, axis1=-2, axis2=-1)), axis=-1)
  # The determinant of a Householder reflector is -1. So whenever we actually
  # made a reflection (tau != 0), multiply the result by -1.
  sign_taus = jnp.prod(jnp.where(taus[..., :(n-1)] != 0, -1, 1), axis=-1).astype(sign_diag.dtype)
  return sign_diag * sign_taus, log_abs_det
コード例 #13
0
ファイル: special.py プロジェクト: GregCT/jax
def polygamma(n, x):
    """Evaluate ``_polygamma`` on ``n`` and ``x`` broadcast to a common shape.

    Args:
      n: order argument; its dtype must be an integer type (asserted).
      x: evaluation point(s).

    Returns:
      The result of ``_polygamma`` on the promoted, broadcast arguments.
    """
    assert jnp.issubdtype(lax.dtype(n), jnp.integer)
    n, x = _promote_args_inexact("polygamma", n, x)
    # Both operands are expanded to the joint broadcast shape before the call.
    out_shape = lax.broadcast_shapes(n.shape, x.shape)
    n_full = jnp.broadcast_to(n, out_shape)
    x_full = jnp.broadcast_to(x, out_shape)
    return _polygamma(n_full, x_full)
コード例 #14
0
ファイル: linalg.py プロジェクト: xueeinstein/jax
def eigh_tridiagonal(d,
                     e,
                     *,
                     eigvals_only=False,
                     select='a',
                     select_range=None,
                     tol=None):
    """Compute eigenvalues of a Hermitian tridiagonal matrix by bisection.

    The eigenvalues are located with parallel bisection searches driven by
    Sturm sequence counts, starting from an interval derived from the
    Gershgorin circle theorem.

    Args:
      d: diagonal of the matrix, shape ``(n,)``.
      e: off-diagonal of the matrix. Must have the same dtype as ``d``.
      eigvals_only: must be ``True``; eigenvector computation is not
        implemented.
      select: ``'a'`` for all eigenvalues or ``'i'`` for the eigenvalues whose
        indices lie in ``select_range`` (inclusive). ``'v'`` is not
        implemented.
      select_range: ``(first, last)`` index pair, required when ``select`` is
        ``'i'``.
      tol: optional absolute tolerance. The tolerance actually used is never
        smaller than machine epsilon times an upper bound on the matrix norm.

    Returns:
      The selected eigenvalue estimates. For ``n <= 1`` the real part of the
      diagonal is returned directly.

    Raises:
      NotImplementedError: if ``eigvals_only`` is false or ``select`` is
        ``'v'``.
      TypeError: if ``d`` and ``e`` have different dtypes, or a dtype other
        than float32, float64, complex64 or complex128.
      ValueError: if ``select`` is not one of ``'a'``, ``'i'``, ``'v'``, if
        ``select='i'`` is used without ``select_range``, or if the index
        range is empty.
    """
    if not eigvals_only:
        raise NotImplementedError(
            "Calculation of eigenvectors is not implemented")

    def _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, x):
        """Implements the Sturm sequence recurrence."""
        n = alpha.shape[0]
        zeros = jnp.zeros(x.shape, dtype=jnp.int32)
        ones = jnp.ones(x.shape, dtype=jnp.int32)

        # The first step in the Sturm sequence recurrence
        # requires special care if x is equal to alpha[0].
        def sturm_step0():
            q = alpha[0] - x
            count = jnp.where(q < 0, ones, zeros)
            q = jnp.where(alpha[0] == x, alpha0_perturbation, q)
            return q, count

        # Subsequent steps all take this form:
        def sturm_step(i, q, count):
            q = alpha[i] - beta_sq[i - 1] / q - x
            count = jnp.where(q <= pivmin, count + 1, count)
            q = jnp.where(q <= pivmin, jnp.minimum(q, -pivmin), q)
            return q, count

        # The first step initializes q and count.
        q, count = sturm_step0()

        # Peel off ((n-1) % blocksize) steps from the main loop, so we can run
        # the bulk of the iterations unrolled by a factor of blocksize.
        blocksize = 16
        i = 1
        peel = (n - 1) % blocksize
        unroll_cnt = peel

        # NOTE: ``unroll_cnt`` is read from the enclosing scope at call time,
        # so this helper runs ``peel`` steps now and ``blocksize`` steps once
        # ``unroll_cnt`` is rebound below.
        def unrolled_steps(args):
            start, q, count = args
            for j in range(unroll_cnt):
                q, count = sturm_step(start + j, q, count)
            return start + unroll_cnt, q, count

        i, q, count = unrolled_steps((i, q, count))

        # Run the remaining steps of the Sturm sequence using a partially
        # unrolled while loop.
        unroll_cnt = blocksize

        def cond(iqc):
            i, q, count = iqc
            return jnp.less(i, n)

        _, _, count = lax.while_loop(cond, unrolled_steps, (i, q, count))
        return count

    alpha = jnp.asarray(d)
    beta = jnp.asarray(e)
    supported_dtypes = (jnp.float32, jnp.float64, jnp.complex64,
                        jnp.complex128)
    if alpha.dtype != beta.dtype:
        raise TypeError(
            "diagonal and off-diagonal values must have same dtype, "
            f"got {alpha.dtype} and {beta.dtype}")
    # The dtypes are known to be equal here, so checking one is sufficient.
    if alpha.dtype not in supported_dtypes:
        raise TypeError(
            "Only float32, float64, complex64 and complex128 inputs are "
            "supported as inputs to jax.scipy.linalg.eigh_tridiagonal, got "
            f"{alpha.dtype} and {beta.dtype}")
    n = alpha.shape[0]
    if n <= 1:
        return jnp.real(alpha)

    # Only the real diagonal and the squared magnitude of the off-diagonal
    # enter the recurrence, so complex inputs are reduced to real quantities.
    if jnp.issubdtype(alpha.dtype, jnp.complexfloating):
        alpha = jnp.real(alpha)
        beta_sq = jnp.real(beta * jnp.conj(beta))
        beta_abs = jnp.sqrt(beta_sq)
    else:
        beta_abs = jnp.abs(beta)
        beta_sq = jnp.square(beta)

    # Estimate the largest and smallest eigenvalues of T using the Gershgorin
    # circle theorem.
    off_diag_abs_row_sum = jnp.concatenate(
        [beta_abs[:1], beta_abs[:-1] + beta_abs[1:], beta_abs[-1:]], axis=0)
    lambda_est_max = jnp.amax(alpha + off_diag_abs_row_sum)
    lambda_est_min = jnp.amin(alpha - off_diag_abs_row_sum)
    # Upper bound on 2-norm of T.
    t_norm = jnp.maximum(jnp.abs(lambda_est_min), jnp.abs(lambda_est_max))

    # Compute the smallest allowed pivot in the Sturm sequence to avoid
    # overflow.
    finfo = np.finfo(alpha.dtype)
    one = np.ones([], dtype=alpha.dtype)
    safemin = np.maximum(one / finfo.max, (one + finfo.eps) * finfo.tiny)
    pivmin = safemin * jnp.maximum(1, jnp.amax(beta_sq))
    alpha0_perturbation = jnp.square(finfo.eps * beta_abs[0])
    abs_tol = finfo.eps * t_norm
    if tol is not None:
        abs_tol = jnp.maximum(tol, abs_tol)

    # In the worst case, when the absolute tolerance is eps*lambda_est_max and
    # lambda_est_max = -lambda_est_min, we have to take as many bisection steps
    # as there are bits in the mantissa plus 1.
    max_it = finfo.nmant + 1

    # Determine the indices of the desired eigenvalues, based on select and
    # select_range.
    if select == 'a':
        target_counts = jnp.arange(n, dtype=jnp.int32)
    elif select == 'i':
        # Fail with a clear message instead of an incidental
        # "'NoneType' object is not subscriptable" error.
        if select_range is None:
            raise ValueError(
                "select_range must be provided when select='i'.")
        if select_range[0] > select_range[1]:
            raise ValueError('Got empty index range in select_range.')
        target_counts = jnp.arange(select_range[0],
                                   select_range[1] + 1,
                                   dtype=jnp.int32)
    elif select == 'v':
        # TODO(phawkins): requires dynamic shape support.
        raise NotImplementedError("eigh_tridiagonal(..., select='v') is not "
                                  "implemented")
    else:
        raise ValueError("select must have a value in {'a', 'i', 'v'}.")

    # Run binary search for all desired eigenvalues in parallel, starting from
    # an interval slightly wider than the estimated
    # [lambda_est_min, lambda_est_max].
    fudge = 2.1  # Widens the Gershgorin starting interval a bit.
    norm_slack = jnp.array(n, alpha.dtype) * fudge * finfo.eps * t_norm
    lower = lambda_est_min - norm_slack - 2 * fudge * pivmin
    upper = lambda_est_max + norm_slack + fudge * pivmin

    # Pre-broadcast the scalars used in the Sturm sequence for improved
    # performance.
    target_shape = jnp.shape(target_counts)
    lower = jnp.broadcast_to(lower, shape=target_shape)
    upper = jnp.broadcast_to(upper, shape=target_shape)
    mid = 0.5 * (upper + lower)
    pivmin = jnp.broadcast_to(pivmin, target_shape)
    alpha0_perturbation = jnp.broadcast_to(alpha0_perturbation, target_shape)

    # Start parallel binary searches.
    def cond(args):
        i, lower, _, upper = args
        # Stop after max_it steps or once every interval is within tolerance.
        return jnp.logical_and(jnp.less(i, max_it),
                               jnp.less(abs_tol, jnp.amax(upper - lower)))

    def body(args):
        i, lower, mid, upper = args
        counts = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, mid)
        # Shrink each interval towards the eigenvalue with the target index.
        lower = jnp.where(counts <= target_counts, mid, lower)
        upper = jnp.where(counts > target_counts, mid, upper)
        mid = 0.5 * (lower + upper)
        return i + 1, lower, mid, upper

    _, _, mid, _ = lax.while_loop(cond, body, (0, lower, mid, upper))
    return mid
コード例 #15
0
 def _to_inexact_type(type):
     return type if jnp.issubdtype(type, jnp.inexact) else jnp.float_
コード例 #16
0
ファイル: ndimage.py プロジェクト: gnecula/jax
def _round_half_away_from_zero(a):
    return a if jnp.issubdtype(a.dtype, jnp.integer) else lax.round(a)