Пример #1
0
    def body(k, state):
        """Performs elimination step ``k`` of an LU factorization with partial pivoting.

        Uses masks instead of slices so that every array keeps the same shape
        on every step. Presumably this is the body of a rolled loop such as
        ``lax.fori_loop`` over ``k``; confirm against the caller.

        ``m`` and ``n`` are read from the enclosing scope. They are presumably
        the row and column counts of ``a``; TODO confirm against the caller.

        Args:
          k: index of the column being eliminated on this step.
          state: tuple ``(pivot, perm, a)``.
            ``pivot[k]`` receives the row chosen as the pivot for this step.
            ``perm`` is the running row permutation.
            ``a`` is the matrix being factored; L and U accumulate in it.

        Returns:
          The updated ``(pivot, perm, a)`` tuple.
        """
        pivot, perm, a = state
        m_idx = jnp.arange(m)
        n_idx = jnp.arange(n)

        # Pivot magnitude. Complex entries use |re| + |im|, which avoids a
        # square root.
        if jnp.issubdtype(a.dtype, jnp.complexfloating):
            t = a[:, k]
            magnitude = jnp.abs(jnp.real(t)) + jnp.abs(jnp.imag(t))
        else:
            magnitude = jnp.abs(a[:, k])
        # Only rows k.. are eligible. Rows above k are masked to -inf.
        i = jnp.argmax(jnp.where(m_idx >= k, magnitude, -jnp.inf))
        pivot = ops.index_update(pivot, ops.index[k], i)

        # Swap rows k and i of the matrix.
        a = ops.index_update(a, ops.index[[k, i], ], a[[i, k], ])

        # Record the same row swap in the permutation.
        perm = ops.index_update(perm, ops.index[[i, k], ], perm[[k, i], ])

        # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes.
        # A zero pivot means every candidate entry in column k is zero, so
        # the column is left unscaled, as LAPACK getf2 does. Dividing anyway
        # would produce 0/0 = NaN, and the update below would spread it
        # through the rest of the matrix.
        x = a[k, k]
        a = ops.index_update(a, ops.index[:, k],
                             jnp.where((m_idx > k) & (x != 0),
                                       a[:, k] / x, a[:, k]))

        # a[k+1:, k+1:] -= jnp.outer(a[k+1:, k], a[k, k+1:])
        a = a - jnp.where(
            (m_idx[:, None] > k) & (n_idx > k), jnp.outer(a[:, k], a[k, :]),
            jnp.array(0, dtype=a.dtype))
        return pivot, perm, a
Пример #2
0
def csd(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
        detrend='constant', return_onesided=True, scaling='density',
        axis=-1, average='mean'):
  """Cross spectral density of ``x`` and ``y``, averaged over segments.

  The signature appears to mirror ``scipy.signal.csd``. Every argument except
  ``average`` is passed unchanged to ``_spectral_helper``, together with
  ``mode='psd'``.

  Args:
    average: how the per-segment spectra along the last axis of the helper's
      output are combined, either ``'median'`` or ``'mean'``. It is only
      consulted when there is more than one segment.

  Returns:
    ``(freqs, Pxy)``. ``freqs`` is taken directly from ``_spectral_helper``
    and ``Pxy`` is the averaged spectrum. ``Pxy`` is made complex whenever
    ``y`` is not ``None``.

  Raises:
    ValueError: if there is more than one segment and ``average`` is neither
      ``'median'`` nor ``'mean'``.
  """
  freqs, _, spectra = _spectral_helper(
      x, y, fs, window, nperseg, noverlap, nfft, detrend, return_onesided,
      scaling, axis, mode='psd')
  if y is not None:
    # A separately supplied second signal always gives a complex result.
    spectra = spectra + 0j

  # Nothing to average: there is no segment axis, or the result is empty.
  if spectra.ndim < 2 or spectra.size == 0:
    return freqs, spectra

  num_segments = spectra.shape[-1]
  if num_segments <= 1:
    # A single segment: only drop the trailing length-1 axis.
    return freqs, jnp.reshape(spectra, spectra.shape[:-1])

  if average == 'median':
    bias = signal_helper._median_bias(num_segments).astype(spectra.dtype)
    # The median is taken separately over real and imaginary parts.
    if jnp.iscomplexobj(spectra):
      averaged = (jnp.median(jnp.real(spectra), axis=-1)
                  + 1j * jnp.median(jnp.imag(spectra), axis=-1))
    else:
      averaged = jnp.median(spectra, axis=-1)
    return freqs, averaged / bias

  if average == 'mean':
    return freqs, spectra.mean(axis=-1)

  raise ValueError(f'average must be "median" or "mean", got {average}')
Пример #3
0
def _sph_harm(m: jnp.ndarray, n: jnp.ndarray, theta: jnp.ndarray,
              phi: jnp.ndarray, n_max: int) -> jnp.ndarray:
    """Computes spherical harmonics of order ``m`` and degree ``n``.

    Values are first built for order ``|m|``. Entries with negative order are
    then replaced by ``(-1)**|m| * conj(value)``.

    Args:
      m: orders. Negative entries are handled through the reflection above.
      n: degrees. The Legendre table is indexed with ``jnp.arange(len(n))``,
        which implies one entry per evaluation point.
      theta: angle multiplied by ``|m|`` to form the complex phase.
      phi: angle whose cosine is passed to the Legendre routine (colatitude).
      n_max: passed to ``_gen_associated_legendre``. Presumably this is the
        largest degree tabulated; TODO confirm.

    Returns:
      Complex array of harmonic values.
    """
    order = abs(m)

    # The literal True flag is passed through unchanged. It presumably
    # requests normalized functions; TODO confirm against
    # _gen_associated_legendre.
    legendre_table = _gen_associated_legendre(n_max, jnp.cos(phi), True)
    point_idx = jnp.arange(len(n))
    p_val = legendre_table[order, n, point_idx]

    phase_angle = order * theta
    phase = lax.complex(jnp.cos(phase_angle), jnp.sin(phase_angle))
    positive = lax.complex(p_val * jnp.real(phase), p_val * jnp.imag(phase))

    # Negative order: reflect the |m| result.
    reflected = (-1.0) ** order * jnp.conjugate(positive)
    return jnp.where(m < 0, reflected, positive)