예제 #1
0
def _fft_norm(s, func_name, norm):
  """Return the normalization factor selected by ``norm``.

  Args:
    s: array whose product is used as the transform size (passed straight to
      ``jnp.prod``).
    func_name: name of the calling routine; a name starting with ``'i'``
      selects the inverse-direction factor.
    norm: ``None``, ``"backward"``, ``"ortho"`` or ``"forward"``.

  Returns:
    ``1`` for ``None``/``"backward"``. For ``"ortho"`` the factor is
    ``sqrt(prod(s))`` when ``func_name`` starts with ``'i'`` and its
    reciprocal otherwise; for ``"forward"`` it is ``prod(s)`` or its
    reciprocal under the same rule.

  Raises:
    ValueError: if ``norm`` is none of the accepted values.
  """
  if norm is None or norm == "backward":
    return 1
  if norm == "ortho":
    size = jnp.sqrt(jnp.prod(s))
  elif norm == "forward":
    size = jnp.prod(s)
  else:
    # The trailing space matters: the two literals are concatenated, and
    # without it the message read '"backward","ortho" or "forward".'.
    raise ValueError(f'Invalid norm value {norm}; should be "backward", '
                     '"ortho" or "forward".')
  return size if func_name.startswith('i') else 1 / size
예제 #2
0
파일: linalg.py 프로젝트: xueeinstein/jax
def _slogdet_qr(a):
  """Compute ``(sign, log|det|)`` of ``a`` via a QR factorization.

  One reason to prefer QR over LU is that it is more amenable to a fast
  batched implementation on TPU because of the lack of row pivoting.

  Args:
    a: real-valued array; the last dimension is used as the matrix size ``n``.

  Returns:
    A ``(sign, log_abs_det)`` pair, each reduced over the last axis of the
    factorization's diagonal.

  Raises:
    NotImplementedError: if ``a`` has a complex floating dtype.
  """
  if jnp.issubdtype(lax.dtype(a), jnp.complexfloating):
    raise NotImplementedError("slogdet method='qr' not implemented for complex "
                              "inputs")
  n = a.shape[-1]
  a, taus = lax_linalg.geqrf(a)
  # The determinant of a triangular matrix is the product of its diagonal
  # elements. We are working in log space, so we compute the magnitude as the
  # sum of the log-absolute diagonal values, and we compute the sign separately.
  # Only the diagonal is passed through log(): taking the log of the whole
  # matrix wastes O(n^2) work and can produce -inf/NaN on off-diagonal
  # entries that never contribute to the result.
  r_diag = jnp.diagonal(a, axis1=-2, axis2=-1)
  log_abs_det = jnp.sum(jnp.log(jnp.abs(r_diag)), axis=-1)
  sign_diag = jnp.prod(jnp.sign(r_diag), axis=-1)
  # The determinant of a Householder reflector is -1. So whenever we actually
  # made a reflection (tau != 0), multiply the result by -1.
  sign_taus = jnp.prod(jnp.where(taus[..., :(n-1)] != 0, -1, 1), axis=-1).astype(sign_diag.dtype)
  return sign_diag * sign_taus, log_abs_det
예제 #3
0
파일: linalg.py 프로젝트: cloudhan/jax
def _slogdet_lu(a):
    """Compute ``(sign, log|det|)`` of ``a`` from its LU factorization.

    The sign combines the parity of the pivot entries that differ from their
    own index with, for real input, the number of negative diagonal entries
    of the factorization; for complex input it is the product of the
    diagonal's unit phases. If any diagonal entry is exactly zero the result
    is ``(0, -inf)``.

    Args:
        a: array whose trailing dimension gives the matrix size.

    Returns:
        ``(sign, logdet)`` where ``logdet`` is the real part of the summed
        log-absolute diagonal.
    """
    real_input = not jnp.iscomplexobj(a)
    out_dtype = lax.dtype(a)
    zero = jnp.array(0, dtype=out_dtype)

    factors, pivots, _ = lax_linalg.lu(a)
    u_diag = jnp.diagonal(factors, axis1=-2, axis2=-1)
    singular = jnp.any(u_diag == zero, axis=-1)

    # Count pivot entries that differ from their own row index; the index
    # vector gets leading unit axes so it lines up with batched pivots.
    batch_axes = range(pivots.ndim - 1)
    row_index = lax.expand_dims(jnp.arange(a.shape[-1]), batch_axes)
    flips = jnp.count_nonzero(pivots != row_index, axis=-1)

    if real_input:
        phase = jnp.array(1, dtype=out_dtype)
        # Each negative diagonal entry contributes one more sign flip.
        flips = flips + jnp.count_nonzero(u_diag < 0, axis=-1)
    else:
        phase = jnp.prod(u_diag / jnp.abs(u_diag), axis=-1)

    parity_sign = jnp.array(-2 * (flips % 2) + 1, dtype=out_dtype)
    sign = jnp.where(singular, zero, phase * parity_sign)

    log_abs_sum = jnp.sum(jnp.log(jnp.abs(u_diag)), axis=-1)
    neg_inf = jnp.array(-jnp.inf, dtype=out_dtype)
    logdet = jnp.where(singular, neg_inf, log_abs_sum)
    return sign, jnp.real(logdet)
예제 #4
0
파일: linalg.py 프로젝트: ahoenselaar/jax
def slogdet(a):
    """Return ``(sign, log|det|)`` for a square matrix or batch of matrices.

    The input is converted with ``jnp.asarray`` and dtype-promoted, then
    factorized with LU. The sign is built from the parity of pivot entries
    that differ from their index plus, for real input, the count of negative
    diagonal entries; for complex input it is the product of the diagonal's
    unit phases. A diagonal entry that is exactly zero yields ``(0, -inf)``.

    Args:
        a: array-like of shape ``[..., n, n]``.

    Returns:
        ``(sign, logdet)``; ``logdet`` is the real part of the summed
        log-absolute diagonal.

    Raises:
        ValueError: if ``a`` has fewer than two dimensions or its last two
            dimensions differ.
    """
    a = _promote_arg_dtypes(jnp.asarray(a))
    out_dtype = lax.dtype(a)
    shape = jnp.shape(a)
    if not (len(shape) >= 2 and shape[-1] == shape[-2]):
        raise ValueError(
            "Argument to slogdet() must have shape [..., n, n], got {}".format(shape))

    zero = jnp.array(0, dtype=out_dtype)
    factors, pivots, _ = lax_linalg.lu(a)
    u_diag = jnp.diagonal(factors, axis1=-2, axis2=-1)
    singular = jnp.any(u_diag == zero, axis=-1)

    # A pivot entry that differs from its own index counts as one sign flip.
    flips = jnp.count_nonzero(pivots != jnp.arange(shape[-1]), axis=-1)
    if not jnp.iscomplexobj(a):
        phase = jnp.array(1, dtype=out_dtype)
        # Each negative diagonal entry contributes one more sign flip.
        flips = flips + jnp.count_nonzero(u_diag < 0, axis=-1)
    else:
        phase = jnp.prod(u_diag / jnp.abs(u_diag), axis=-1)

    parity_sign = jnp.array(-2 * (flips % 2) + 1, dtype=out_dtype)
    sign = jnp.where(singular, zero, phase * parity_sign)

    log_abs_sum = jnp.sum(jnp.log(jnp.abs(u_diag)), axis=-1)
    neg_inf = jnp.array(-jnp.inf, dtype=out_dtype)
    logdet = jnp.where(singular, neg_inf, log_abs_sum)
    return sign, jnp.real(logdet)