Beispiel #1
0
 def body_fun(i, p_val):
     """Apply one additive update step to ``p_val`` and return the result.

     The einsum subscripts require ``p_val`` to be rank 3 (``ijk``), the
     per-step coefficients ``d0_mask_3d[i]`` / ``d1_mask_3d[i]`` to be rank 2
     (``ij``) and ``x`` to be rank 1 (``k``). All three of ``d0_mask_3d``,
     ``d1_mask_3d`` and ``x`` come from the enclosing scope.

     NOTE(review): the ``(i, p_val)`` signature looks like a loop body for
     something like ``lax.fori_loop`` -- confirm against the caller.

     Args:
       i: Index selecting the coefficient slices for this step.
       p_val: Current rank-3 value being updated.

     Returns:
       ``p_val`` plus the increment computed for step ``i``.
     """
     # Copies of p_val rolled by one and two positions along axis 1.
     rolled_once = jnp.roll(p_val, shift=1, axis=1)
     rolled_twice = jnp.roll(p_val, shift=2, axis=1)

     # First term: scale the once-rolled values by x along the last axis,
     # then by the first coefficient mask over the leading two axes.
     scaled_by_x = jnp.einsum('ijk,k->ijk', rolled_once, x)
     first_term = jnp.einsum('ij,ijk->ijk', d0_mask_3d[i], scaled_by_x)

     # Second term: the twice-rolled values scaled by the second mask.
     second_term = jnp.einsum('ij,ijk->ijk', d1_mask_3d[i], rolled_twice)

     return p_val + (first_term - second_term)
Beispiel #2
0
def _roots_with_zeros(p, num_leading_zeros):
    """Compute roots of ``p`` given how many of its leading entries are zero.

    Args:
      p: Coefficient array passed on to ``_roots_no_zeros`` after the
        leading zeros have been rolled to the end.
      num_leading_zeros: Number of leading zero coefficients in ``p``.

    Returns:
      The roots array with its last ``num_leading_zeros`` entries replaced
      by ``nan + nan*j``.
    """
    # An all-zero p would trigger LAPACK errors; substitute ones instead.
    is_all_zero = len(p) == num_leading_zeros
    p = _where(is_all_zero, 1.0, p)

    # Move the leading zeros to the tail before solving for the roots.
    rotated = roll(p, -num_leading_zeros)
    roots = _roots_no_zeros(rotated)

    # Sorting on the "is zero" key pushes the zero roots to the end.
    _, roots = lax.sort_key_val(roots == 0, roots)

    # The trailing num_leading_zeros slots are reported as NaN.
    num_valid = roots.size - num_leading_zeros
    is_valid = arange(roots.size) < num_valid
    return _where(is_valid, roots, complex(np.nan, np.nan))
Beispiel #3
0
def ifftshift(x, axes=None):
  """Roll the selected axes of ``x`` by ``-(size // 2)`` positions.

  Args:
    x: Array-like input; converted with ``jnp.asarray``.
    axes: ``None`` to shift every axis, a single integer axis (any object
      supporting ``__index__``, e.g. a NumPy integer scalar), or an iterable
      of integer axes.

  Returns:
    The result of ``jnp.roll`` applied to ``x`` with the computed shifts.
  """
  import operator  # local import: the file's import block is not visible here

  x = jnp.asarray(x)
  if axes is None:
    axes = tuple(range(x.ndim))
    shift = [-(dim // 2) for dim in x.shape]
  else:
    try:
      # Accept any integer-like scalar, not only the builtin ``int``; a plain
      # isinstance(axes, int) check sends e.g. np.int64 down the iterable
      # branch, where iterating over it raises TypeError.
      axis = operator.index(axes)
    except TypeError:
      # Not an integer scalar: treat it as an iterable of axes.
      shift = [-(x.shape[ax] // 2) for ax in axes]
    else:
      axes = axis
      shift = -(x.shape[axis] // 2)

  return jnp.roll(x, shift, axes)