示例#1
0
文件: signal.py 项目: jbampton/jax
def convolve2d(in1, in2, mode='full', boundary='fill', fillvalue=0,
               precision=None):
  """Convolve two 2-D arrays.

  Only zero padding (``boundary='fill'`` together with ``fillvalue=0``) is
  supported.

  Args:
    in1, in2: 2-dimensional inputs.
    mode: convolution mode, forwarded unchanged to ``_convolve_nd``.
    boundary, fillvalue: must be ``'fill'`` and ``0`` respectively.
    precision: forwarded unchanged to ``_convolve_nd``.

  Returns:
    Whatever ``_convolve_nd`` returns for the two inputs.

  Raises:
    NotImplementedError: for any other boundary handling.
    ValueError: if either input is not 2-dimensional.
  """
  if not (boundary == 'fill' and fillvalue == 0):
    raise NotImplementedError("convolve2d() only supports boundary='fill', fillvalue=0")
  if any(jnp.ndim(operand) != 2 for operand in (in1, in2)):
    raise ValueError("convolve2d() only supports 2-dimensional inputs.")
  return _convolve_nd(in1, in2, mode, precision=precision)
示例#2
0
文件: linalg.py 项目: xueeinstein/jax
def _solve_triangular(a, b, trans, lower, unit_diagonal):
    """Solve the triangular system ``op(a) @ x = b`` for ``x``.

    Args:
      a: triangular coefficient matrix (or batch of matrices).
      b: right-hand side; either a matrix batch matching ``a``, or a batch of
        vectors with one fewer dimension than ``a``.
      trans: which ``op`` to apply to ``a``: ``0``/``"N"`` for ``a`` itself,
        ``1``/``"T"`` for its transpose, ``2``/``"C"`` for its conjugate
        transpose.
      lower: whether the lower (``True``) or upper triangle of ``a`` is used.
      unit_diagonal: whether the diagonal of ``a`` is assumed to be all ones.

    Returns:
      The solution ``x``, with the same rank as ``b``.

    Raises:
      ValueError: if ``trans`` is not one of the values listed above.
    """
    if trans in (0, "N"):
        transpose_a = conjugate_a = False
    elif trans in (1, "T"):
        transpose_a, conjugate_a = True, False
    elif trans in (2, "C"):
        transpose_a = conjugate_a = True
    else:
        raise ValueError(f"Invalid 'trans' value {trans}")

    a, b = _promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))

    # lax_linalg.triangular_solve only accepts a matrix right-hand side, so a
    # vector b is lifted to a one-column matrix and that column is stripped
    # from the result afterwards.
    vector_rhs = jnp.ndim(b) + 1 == jnp.ndim(a)
    rhs = b[..., None] if vector_rhs else b
    solution = lax_linalg.triangular_solve(
        a, rhs, left_side=True, lower=lower, transpose_a=transpose_a,
        conjugate_a=conjugate_a, unit_diagonal=unit_diagonal)
    return solution[..., 0] if vector_rhs else solution
示例#3
0
文件: signal.py 项目: GJBoth/jax
def correlate2d(in1,
                in2,
                mode='full',
                boundary='fill',
                fillvalue=0,
                precision=None):
    """Cross-correlate two 2-D arrays.

    The correlation is expressed as a convolution computed by ``_convolve_nd``.
    One operand is flipped and the other conjugated. Depending on the branch,
    the convolution result is then flipped and/or conjugated back.

    Args:
      in1, in2: 2-dimensional inputs. They must expose ``.shape`` and
        ``.conj()`` (i.e. array objects; nested lists fail here).
      mode: convolution mode forwarded to ``_convolve_nd``. ``"same"`` and
        ``"valid"`` get dedicated branches; any other value takes the last
        branch, which is presumably the ``"full"`` path (TODO confirm which
        modes ``_convolve_nd`` accepts).
      boundary, fillvalue: only ``'fill'`` with ``0`` is supported.
      precision: forwarded unchanged to ``_convolve_nd``.

    Returns:
      The 2-D cross-correlation of ``in1`` with ``in2``.

    Raises:
      NotImplementedError: for any boundary handling other than zero fill.
      ValueError: if either input is not 2-dimensional.
    """
    if boundary != 'fill' or fillvalue != 0:
        raise NotImplementedError(
            "correlate2d() only supports boundary='fill', fillvalue=0")
    if jnp.ndim(in1) != 2 or jnp.ndim(in2) != 2:
        raise ValueError("correlate2d() only supports 2-dimensional inputs.")

    # True when in1 is no larger than in2 along every axis.
    swap = all(s1 <= s2 for s1, s2 in zip(in1.shape, in2.shape))
    # True when both inputs have exactly the same shape.
    same_shape = all(s1 == s2 for s1, s2 in zip(in1.shape, in2.shape))

    if mode == "same":
        # Never swapped: flip in1, conjugate in2, convolve, flip the result.
        in1, in2 = jnp.flip(in1), in2.conj()
        result = jnp.flip(_convolve_nd(in1, in2, mode, precision=precision))
    elif mode == "valid":
        if swap and not same_shape:
            # in2 is at least as large on every axis and strictly larger on
            # one: convolve with the roles swapped. Unlike the swapped branch
            # of the last case below, the result is NOT conjugated here.
            # NOTE(review): presumably this mirrors scipy.signal.correlate2d,
            # which swaps inputs in 'valid' mode without conjugating --
            # confirm against SciPy before changing.
            in1, in2 = jnp.flip(in2), in1.conj()
            result = _convolve_nd(in1, in2, mode, precision=precision)
        else:
            in1, in2 = jnp.flip(in1), in2.conj()
            result = jnp.flip(_convolve_nd(in1, in2, mode,
                                           precision=precision))
    else:
        if swap:
            # Swapped roles: convolving flip(in2) with conj(in1) gives the
            # flipped conjugate-image of the unswapped correlation, and the
            # trailing .conj() restores it (corr(x, y)[k] equals
            # conj(corr(y, x)[-k])).
            in1, in2 = jnp.flip(in2), in1.conj()
            result = _convolve_nd(in1, in2, mode, precision=precision).conj()
        else:
            in1, in2 = jnp.flip(in1), in2.conj()
            result = jnp.flip(_convolve_nd(in1, in2, mode,
                                           precision=precision))
    return result
示例#4
0
def convolve(in1, in2, mode='full', method='auto', precision=None):
    """Convolve two real 1-D arrays.

    Args:
      in1, in2: real-valued 1-dimensional arrays (must have a ``.dtype``).
      mode: convolution mode, forwarded unchanged to ``_convolve_nd``.
      method: has no effect. Any value other than ``'auto'`` only triggers a
        warning.
      precision: forwarded unchanged to ``_convolve_nd``.

    Returns:
      Whatever ``_convolve_nd`` returns for the two inputs.

    Raises:
      NotImplementedError: if either input has a complex dtype.
      ValueError: if either input is not 1-dimensional.
    """
    if method != 'auto':
        warnings.warn("convolve() ignores method argument")
    operands = (in1, in2)
    if any(jnp.issubdtype(arr.dtype, jnp.complexfloating) for arr in operands):
        raise NotImplementedError("convolve() does not support complex inputs")
    if any(jnp.ndim(arr) != 1 for arr in operands):
        raise ValueError("convolve() only supports 1-dimensional inputs.")
    return _convolve_nd(in1, in2, mode, precision=precision)
示例#5
0
def triangular_solve(a,
                     b,
                     left_side: bool = False,
                     lower: bool = False,
                     transpose_a: bool = False,
                     conjugate_a: bool = False,
                     unit_diagonal: bool = False):
    r"""Solve a triangular system of linear equations.

    Computes :math:`X` from either

    .. math::
      \mathit{op}(A) . X = B

    when ``left_side`` is ``True``, or

    .. math::
      X . \mathit{op}(A) = B

    when ``left_side`` is ``False``.

    ``A`` must be a square matrix that is lower or upper triangular.
    :math:`\mathit{op}(A)` is :math:`A`, transposed if ``transpose_a`` is
    ``True`` and/or complex-conjugated if ``conjugate_a`` is ``True``.

    Args:
      a: A batch of matrices with shape ``[..., m, m]``.
      b: A batch of matrices with shape ``[..., m, n]`` if ``left_side`` is
        ``True`` or ``[..., n, m]`` otherwise. ``b`` may instead have one
        fewer dimension than ``a`` (a batch of vectors). It is then solved as
        a single column (``left_side=True``) or a single row
        (``left_side=False``), and the result keeps ``b``'s rank.
      left_side: selects which of the two equations above is solved.
      lower: selects which triangle of ``a`` is used; the other triangle is
        ignored.
      transpose_a: if ``True``, ``a`` is transposed.
      conjugate_a: if ``True``, the complex conjugate of ``a`` is used. Has no
        effect when ``a`` is real.
      unit_diagonal: if ``True``, the diagonal of ``a`` is assumed to be all
        ones and is not accessed.

    Returns:
      The solution ``X``, with the same shape and dtype as ``b``.
    """
    if conjugate_a:
        # Conjugation is a no-op for real matrices; only keep it for complex a.
        conjugate_a = jnp.issubdtype(lax.dtype(a), jnp.complexfloating)

    is_vector = jnp.ndim(b) == jnp.ndim(a) - 1
    if is_vector:
        # Lift the vector to a column (left side) or row (right side) matrix.
        b = jnp.expand_dims(b, -1 if left_side else -2)

    out = triangular_solve_p.bind(
        a, b,
        left_side=left_side, lower=lower, transpose_a=transpose_a,
        conjugate_a=conjugate_a, unit_diagonal=unit_diagonal)

    if not is_vector:
        return out
    # Drop the unit dimension that was added above.
    return out[..., 0] if left_side else out[..., 0, :]
示例#6
0
def triangular_solve(a, b, left_side=False, lower=False, transpose_a=False,
                     conjugate_a=False, unit_diagonal=False):
  """Solve ``op(a) . x = b`` (``left_side=True``) or ``x . op(a) = b``.

  ``a`` is a (batch of) triangular square matrices; ``lower`` selects which
  triangle is used. ``b`` may have one fewer dimension than ``a``. It is then
  solved as a single column (``left_side=True``) or row
  (``left_side=False``), and the result keeps ``b``'s rank.
  ``conjugate_a`` is dropped when ``a`` is not complex.
  """
  if conjugate_a:
    conjugate_a = jnp.issubdtype(lax.dtype(a), jnp.complexfloating)
  vector_rhs = jnp.ndim(b) == jnp.ndim(a) - 1
  unit_axis = -1 if left_side else -2
  rhs = jnp.expand_dims(b, unit_axis) if vector_rhs else b
  solution = triangular_solve_p.bind(
      a, rhs, left_side=left_side, lower=lower, transpose_a=transpose_a,
      conjugate_a=conjugate_a, unit_diagonal=unit_diagonal)
  if vector_rhs:
    # Remove the unit dimension introduced for the vector right-hand side.
    solution = solution[..., 0] if left_side else solution[..., 0, :]
  return solution
示例#7
0
文件: linalg.py 项目: ahoenselaar/jax
def inv(a):
    """Invert a square matrix, or each matrix in a batch of square matrices.

    The inverse is obtained by solving ``a @ x = I``, with the identity
    broadcast over ``a``'s leading batch dimensions.

    Args:
      a: array of shape ``[..., n, n]``.

    Returns:
      The result of ``solve`` against the broadcast identity.

    Raises:
      ValueError: if ``a`` has fewer than two dimensions or its trailing two
        dimensions differ.
    """
    is_square_batch = jnp.ndim(a) >= 2 and a.shape[-1] == a.shape[-2]
    if not is_square_batch:
        raise ValueError(
            f"Argument to inv must have shape [..., n, n], got {a.shape}.")
    identity = jnp.eye(a.shape[-1], dtype=lax.dtype(a))
    batch_shape = a.shape[:-2]
    return solve(a, lax.broadcast(identity, batch_shape))
示例#8
0
def convolve2d(in1,
               in2,
               mode='full',
               boundary='fill',
               fillvalue=0,
               precision=None):
    """Convolve two real 2-D arrays.

    Only zero padding (``boundary='fill'`` with ``fillvalue=0``) and
    non-complex inputs are supported.

    Args:
      in1, in2: real-valued 2-dimensional arrays (must have a ``.dtype``).
      mode: convolution mode, forwarded unchanged to ``_convolve_nd``.
      boundary, fillvalue: must be ``'fill'`` and ``0`` respectively.
      precision: forwarded unchanged to ``_convolve_nd``.

    Returns:
      Whatever ``_convolve_nd`` returns for the two inputs.

    Raises:
      NotImplementedError: for other boundary handling, or complex inputs.
      ValueError: if either input is not 2-dimensional.
    """
    zero_fill = boundary == 'fill' and fillvalue == 0
    if not zero_fill:
        raise NotImplementedError(
            "convolve2d() only supports boundary='fill', fillvalue=0")
    operands = (in1, in2)
    if any(jnp.issubdtype(arr.dtype, jnp.complexfloating) for arr in operands):
        raise NotImplementedError(
            "convolve2d() does not support complex inputs")
    if any(jnp.ndim(arr) != 2 for arr in operands):
        raise ValueError("convolve2d() only supports 2-dimensional inputs.")
    return _convolve_nd(in1, in2, mode, precision=precision)
示例#9
0
文件: eigh.py 项目: cloudhan/jax
def _mask(x, dims, alternative=0):
  """Keep `x` only inside the leading region given by `dims`.

  Along axis `i`, positions with index `>= dims[i]` are replaced with
  `alternative`, which is broadcast against `x`. A `None` entry leaves that
  axis unrestricted. If every entry is `None`, `x` itself is returned.
  """
  assert jnp.ndim(x) == len(dims)
  combined = None
  for axis, limit in enumerate(dims):
    if limit is None:
      continue
    in_bounds = lax.broadcasted_iota(jnp.int32, x.shape, axis) < limit
    combined = in_bounds if combined is None else (combined & in_bounds)
  if combined is None:
    return x
  return jnp.where(combined, x, alternative)
示例#10
0
def block_diag(*arrs):
  """Build a block-diagonal matrix from the given arrays.

  Each argument is viewed as at least 2-D (``jnp.atleast_2d``) and all
  arguments are promoted to a common dtype. The blocks are then laid out
  along the diagonal of the result, with zeros everywhere else.

  Args:
    *arrs: arrays with at most 2 dimensions. With no arguments, a single
      empty ``(1, 0)`` block is used.

  Returns:
    A 2-D array whose diagonal blocks are the inputs, in argument order.

  Raises:
    ValueError: if any argument has more than 2 dimensions.
  """
  if len(arrs) == 0:
    arrs = [jnp.zeros((1, 0))]
  # Validate ranks before any dtype promotion, so bad input is rejected
  # without conversion work. Report the offending argument's number of
  # dimensions: formatting the array itself made the message wrong ("got
  # <array contents>") and unreadable for large inputs.
  bad_shapes = [i for i, a in enumerate(arrs) if jnp.ndim(a) > 2]
  if bad_shapes:
    first_bad = bad_shapes[0]
    raise ValueError("Arguments to jax.scipy.linalg.block_diag must have at "
                     "most 2 dimensions, got {} at argument {}."
                     .format(jnp.ndim(arrs[first_bad]), first_bad))
  arrs = jnp._promote_dtypes(*arrs)
  arrs = [jnp.atleast_2d(a) for a in arrs]
  acc = arrs[0]
  dtype = lax.dtype(acc)
  for a in arrs[1:]:
    _, c = a.shape
    # Shift the new block right past every column accumulated so far, and
    # widen the accumulator by the new block's column count; stacking the
    # two vertically then places the new block on the diagonal.
    a = lax.pad(a, dtype.type(0), ((0, 0, 0), (acc.shape[-1], 0, 0)))
    acc = lax.pad(acc, dtype.type(0), ((0, 0, 0), (0, c, 0)))
    acc = lax.concatenate([acc, a], dimension=0)
  return acc