コード例 #1
0
def canonicalize_coordinates(q):
    """Wrap the norm of ``q`` into [-pi, pi), rescaling ``q`` only when needed.

    Let ``phi = ||q||`` (sum of squares over all entries of ``q``). When
    ``phi > pi``, ``q`` is rescaled so that its signed length equals
    ``fmod(phi + pi, 2*pi) - pi``. That value may be negative, in which case the
    result points opposite to ``q``. When ``phi <= pi``, ``q`` is returned
    unchanged.

    NOTE(review): presumably ``q`` is an axis-angle / rotation vector, and this
    selects the equivalent rotation with the smallest angle. Confirm against
    callers.

    Args:
        q: array. All of its entries contribute to the norm, so it is treated
            as a single vector.

    Returns:
        Array with the same shape as ``q``.
    """
    min_phi = 0.01
    # Clamp the *squared* norm before taking the square root. Clamping after
    # the sqrt (max(sqrt(s), min_phi)) leaves d sqrt(s)/ds = inf at s == 0.
    # The clamp's zero cotangent times inf then yields NaN gradients at
    # q == 0. The clamped value only matters when phi > pi, so the forward
    # output is the same as clamping after the sqrt.
    phi = jnp.sqrt(jnp.maximum(jnp.sum(q**2), min_phi**2))

    max_phi = jnp.pi
    # phi >= min_phi > 0, so fmod's dividend is positive and the result lies
    # in [-pi, pi).
    canonical_phi = jnp.fmod(phi + max_phi, 2.0 * max_phi) - max_phi

    # phi > pi implies phi is the true (unclamped) norm, so the rescale is exact.
    return jax.lax.select(
        phi > max_phi,
        (canonical_phi / phi) * q,
        q,
    )
コード例 #2
0
def onnx_mod(a, b, fmod=0):
    """Element-wise remainder selected by an ONNX-style ``fmod`` flag.

    A truthy ``fmod`` uses ``jnp.fmod`` (C-style: the result takes the sign of
    the dividend ``a``). A falsy ``fmod`` uses ``jnp.mod`` (floored: the result
    takes the sign of the divisor ``b``).
    """
    remainder_op = jnp.fmod if fmod else jnp.mod
    return remainder_op(a, b)
コード例 #3
0
def fmod(x1, x2):
  """Element-wise C-style remainder (``jnp.fmod``), returned as a ``JaxArray``.

  Any operand that is a ``JaxArray`` is unwrapped to its ``.value`` first.
  Other operands are passed to ``jnp.fmod`` unchanged.
  """
  def _raw(operand):
    return operand.value if isinstance(operand, JaxArray) else operand

  lhs = _raw(x1)
  rhs = _raw(x2)
  return JaxArray(jnp.fmod(lhs, rhs))
コード例 #4
0
ファイル: mylunarlander.py プロジェクト: i-abr/KLE3
def wrap2pi(th):
    """Wrap an angle, or an array of angles, into [-pi, pi).

    The range holds up to floating-point rounding: a remainder that is a tiny
    negative number can round up to exactly ``pi``.

    NOTE(review): ``np`` is whatever the module imported under that name.
    With NumPy this works on scalars and arrays. With ``jax.numpy`` it also
    works under tracing.
    """
    x = np.fmod(th + np.pi, 2.0 * np.pi)
    # fmod keeps the sign of the dividend, so shift negative remainders up by
    # one period to land in [0, 2*pi). np.where is elementwise. The previous
    # lax.cond required a scalar predicate and failed for array-valued input.
    x = np.where(x < 0, x + 2.0 * np.pi, x)
    return x - np.pi
コード例 #5
0
def frft(f, a):
    """
    fast fractional fourier transform.

    The fractional order ``a`` is reduced modulo 4.

    Integer orders are handled directly:
        0 -> identity
        1 -> FFT scaled by 1/sqrt(N)
        2 -> time reversal
        3 -> inverse FFT scaled by sqrt(N)
    Orders 1 and 3 apply the transform to the signal taken in an index order
    circularly shifted by floor(N/2).

    Every other order is first reduced into [0.5, 1.5] using those exact
    transforms. It is then evaluated by chirp pre-multiplication, chirp
    convolution and chirp post-multiplication on a 2x sinc-interpolated
    signal.

    Parameters
        f : [jax.]numpy array
            The signal to be transformed. Treated as 1-D: in the general case
            it is concatenated with 1-D zero padding.
        a : float
            fractional power
    Returns
        data : jax.numpy array (complex64)
            The transformed signal, same length as ``f``.
    reference:
        https://github.com/nanaln/python_frft
    """
    f = device_put(f)
    a = device_put(a)

    ret = jnp.zeros_like(f, dtype=jnp.complex64)
    f = f.astype(jnp.complex64)
    N = f.shape[0]

    # Index permutation: circular shift by floor(N / 2). The integer orders
    # 1 and 3 apply the FFT / inverse FFT to the signal taken in this order.
    shft = jnp.fmod(jnp.arange(N) + jnp.fix(N / 2), N).astype(int)
    sN = jnp.sqrt(N)
    a = jnp.remainder(a, 4.0)

    TRUE = jnp.array(True)
    FALSE = jnp.array(False)

    # Simple (integer-order) cases. Each cond either produces the exact result
    # with done=True or passes (ret, done) through. The closures read
    # ret/done while lax.cond traces them, i.e. before the tuple assignment
    # rebinds those names.
    ret, done = lax.cond(
        a == 0.0,
        lambda _: (f, TRUE),
        lambda _: (ret, FALSE),
        None)

    ret, done = lax.cond(
        a == 2.0,
        lambda _: (jnp.flipud(f), TRUE),
        lambda _: (ret, done),
        None)

    # `.at[...].set(...)` replaces the removed jax.ops.index_update/index API.
    ret, done = lax.cond(
        a == 1.0,
        lambda _: (ret.at[shft].set(jnp.fft.fft(f[shft]) / sN), TRUE),
        lambda _: (ret, done),
        None)

    ret, done = lax.cond(
        a == 3.0,
        lambda _: (ret.at[shft].set(jnp.fft.ifft(f[shft]) * sN), TRUE),
        lambda _: (ret, done),
        None)

    @jit
    def sincinterp(x):
        # 2x upsampling: put x on the even samples of a length-(2N-1) buffer,
        # convolve with sinc(n / 2), and keep the central 2N-1 output samples.
        N = x.shape[0]
        y = jnp.zeros(2 * N - 1, dtype=x.dtype)
        y = y.at[:2 * N:2].set(x)
        xint = fftconvolve(
            y[:2 * N],
            jnp.sinc(jnp.arange(-(2 * N - 3), (2 * N - 2)).T / 2),
        )
        return xint[2 * N - 3: -2 * N + 3]

    @jit
    def chirp_opts(a, f):
        # the general case for 0.5 <= a <= 1.5 (the range other_cases reduces to)
        alpha = a * jnp.pi / 2
        tana2 = jnp.tan(alpha / 2)
        sina = jnp.sin(alpha)
        # Length 4N-3: (N-1) zeros + 2N-1 interpolated samples + (N-1) zeros.
        f = jnp.hstack((jnp.zeros(N - 1), sincinterp(f), jnp.zeros(N - 1))).T

        # chirp premultiplication
        chrp = jnp.exp(-1j * jnp.pi / N * tana2 / 4 *
                       jnp.arange(-2 * N + 2, 2 * N - 1).T ** 2)
        f = chrp * f

        # chirp convolution; keep the 4N-3 samples aligned with f
        c = jnp.pi / N / sina / 4
        ret = fftconvolve(
            jnp.exp(1j * c * jnp.arange(-(4 * N - 4), 4 * N - 3).T ** 2),
            f,
        )
        ret = ret[4 * N - 4:8 * N - 7] * jnp.sqrt(c / jnp.pi)

        # chirp post multiplication
        ret = chrp * ret

        # normalizing constant; every second sample of the central part
        # gives back N output samples
        ret = jnp.exp(-1j * (1 - a) * jnp.pi / 4) * ret[N - 1:-N + 1:2]

        return ret

    def other_cases(a, f):
        # Reduce a (in [0, 4)) into [0.5, 1.5] using the exact integer-order
        # transforms, then run the chirp algorithm.
        a, f = lax.cond(
            a > 2.0,
            lambda _: (a - 2.0, jnp.flipud(f)),
            lambda _: (a, f),
            None)

        a, f = lax.cond(
            a > 1.5,
            lambda _: (a - 1.0, f.at[shft].set(jnp.fft.fft(f[shft]) / sN)),
            lambda _: (a, f),
            None)

        a, f = lax.cond(
            a < 0.5,
            lambda _: (a + 1.0, f.at[shft].set(jnp.fft.ifft(f[shft]) * sN)),
            lambda _: (a, f),
            None)

        return chirp_opts(a, f)

    # Use the general algorithm only when no integer-order case applied.
    ret = lax.cond(
        done,
        lambda _: ret,
        lambda _: other_cases(a, f),
        None)

    return ret
コード例 #6
0
# Step-size sweep: for each dt, simulate n_samples paths with `solver` and
# `problem` (defined outside this section). The terminal angle phi (column 1
# of the last time step) is wrapped into [-pi, pi). For each 1/dt we record
# (sample mean - theoretical_mean, standard error of the sample mean).
dts = [2**(-x) for x in range(1, 10, 1)]  # 1/2, 1/4, ..., 1/512
n_samples = 2**15

# Reference statistics that the sample mean of phi is compared against.
theoretical_mean = 1.10714532096375
theoretical_variance = 0.5951002076847987
# NOTE(review): unused in this section. Presumably this is the expected
# standard error of the mean, for comparison with sigma_bias.
theoretical_sigma = np.sqrt(theoretical_variance / n_samples)

bias_dict = {}

for dt in dts:
    solver.dt = dt
    solution = solver.solve_many(problem, n_samples, seed=0)
    terminal_r = solution['solution_values'][:, -1, 0]
    terminal_phi = solution['solution_values'][:, -1, 1]
    # phi in [-pi, pi). jnp.mod (floored: the result takes the sign of the
    # divisor) is required. jnp.fmod keeps the sign of the dividend, so any
    # phi < -pi would end up in (-3*pi, -pi) and skew the mean.
    terminal_phi_canonical = jnp.mod(terminal_phi + jnp.pi,
                                     2 * jnp.pi) - jnp.pi

    expected_phi = float(jnp.mean(terminal_phi_canonical))
    std_mean_phi = float(jnp.std(terminal_phi_canonical) / jnp.sqrt(n_samples))

    measured_bias = expected_phi - theoretical_mean
    sigma_bias = std_mean_phi

    print(f'{int(1/dt)=}')
    print(f'{measured_bias=}')
    print(f'{sigma_bias=}')

    # dts are powers of two, so int(1 / dt) is exact.
    bias_dict[int(1 / dt)] = (measured_bias, sigma_bias)

# One brace-delimited row per step size: {1/dt,bias,sigma},
for k, (mu, sigma) in bias_dict.items():
    print(f'{{{k},{mu},{sigma}}},')