Example #1
0
def _eval_expint_k(A, B, x):
    """Evaluate the shared rational form used by the later intervals.

    With ``w = 1 / x`` the result is
    ``exp(x) * w * (1 + w * polyval(A, w) / polyval(B, w))``.

    Args:
        A: numerator polynomial coefficients, in the order ``jnp.polyval``
            expects (highest power first).
        B: denominator polynomial coefficients, same ordering.
        x: array of evaluation points; its dtype is used for ``A`` and ``B``.

    Returns:
        The value of the expression above at ``x``.
    """
    numer_coeffs = jnp.array(A, dtype=x.dtype)
    denom_coeffs = jnp.array(B, dtype=x.dtype)
    # NOTE(review): `_lax_const` is defined elsewhere; presumably it builds
    # the constant 1.0 with a dtype matching `x` -- confirm.
    unit = _lax_const(x, 1.0)
    inv_x = unit / x
    ratio = jnp.polyval(numer_coeffs, inv_x) / jnp.polyval(denom_coeffs, inv_x)
    bracket = inv_x * ratio + unit
    return jnp.exp(x) * inv_x * bracket
Example #2
0
 def _stats(self, s):
     p = np.exp(s * s)
     mu = np.sqrt(p)
     mu2 = p * (p - 1)
     g1 = np.sqrt(p - 1) * (2 + p)
     g2 = np.polyval([1, 2, 3, 0, -6.0], p)
     return mu, mu2, g1, g2
Example #3
0
def odeint(ofunc, y0, t, args=(), rtol=1e-7, atol=1e-9, return_evals=False):
    """Integrate ``dy/dt = ofunc(y, t, *args)`` adaptively and sample it at ``t``.

    The solver takes adaptive steps (via ``runge_kutta_step``), accepting a
    step only when every entry of the error ratio is ``<= 1``, and evaluates
    the fitted interpolation polynomial at each requested output time.

    Args:
        ofunc: callable ``ofunc(y, t, *args)`` returning the time derivative.
        y0: initial state at ``t[0]``.
        t: output times; must be strictly increasing or strictly decreasing.
            A decreasing sequence is handled by negating time internally.
        args: extra positional arguments forwarded to ``ofunc``.
        rtol: relative tolerance passed to the step-size / error helpers.
        atol: absolute tolerance passed to the step-size / error helpers.
        return_evals: when true, also return the accepted evaluations.

    Returns:
        ``np.stack(solution)`` with one entry per element of ``t``.  When
        ``return_evals`` is true, a pair ``(solution, zip(*evals))`` where
        ``evals`` holds ``(time, y, f)`` triples for the initial point and
        every accepted step.  For a decreasing ``t`` these times are the
        internally negated ones.

    Raises:
        AssertionError: if ``t`` is not strictly monotonic, or if the step
            size underflows so that ``cur_t + dt == cur_t``.
    """
    if len(args) > 0:
        func = lambda y, t: ofunc(y, t, *args)
    else:
        func = ofunc

    # Reverse time if necessary.
    t = np.array(t)
    if t[-1] < t[0]:
        t = -t
        reversed_func = func
        func = lambda y, t: -reversed_func(y, -t)
    assert np.all(
        t[1:] > t[:-1]), 't must be strictly increasing or decreasing'

    f0 = func(y0, t[0])
    dt = initial_step_size(func, t[0], y0, 4, rtol, atol, f0)
    interp_coeff = np.array([y0] * 5)

    solution = [y0]
    cur_t = t[0]
    cur_y = y0
    cur_f = f0

    if return_evals:
        # The initial evaluation is at t[0]; the previous code referenced an
        # undefined name `t0` here and raised NameError.
        evals = [(t[0], y0, f0)]

    for output_t in t[1:]:
        # Interpolate through to the next time point, integrating as necessary.
        # Because t is strictly increasing, the first output time is beyond
        # cur_t, so at least one step is accepted (and `last_t` assigned)
        # before `last_t` is read below.
        while cur_t < output_t:
            next_t = cur_t + dt
            assert next_t > cur_t, 'underflow in dt {}'.format(dt)

            next_y, next_f, next_y_error, k =\
                runge_kutta_step(func, cur_y, cur_f, cur_t, dt)
            error_ratios = error_ratio(next_y_error, atol, rtol, cur_y, next_y)

            if np.all(error_ratios <= 1):  # Accept the step?
                interp_coeff = interp_fit_dopri(cur_y, next_y, k, dt)
                cur_y = next_y
                cur_f = next_f
                last_t = cur_t
                cur_t = next_t

                if return_evals:
                    evals.append((cur_t, cur_y, cur_f))

            # The step size is re-tuned after every attempt, accepted or not.
            dt = optimal_step_size(dt, error_ratios)

        # Position of the output time inside the last accepted step, in [0, 1].
        relative_output_time = (output_t - last_t) / (cur_t - last_t)
        output_y = np.polyval(interp_coeff, relative_output_time)
        solution.append(output_y)
    if return_evals:
        return np.stack(solution), zip(*evals)
    return np.stack(solution)
Example #4
0
    def _fori_body_fun(func, i, val):
        """Advance the integrator up to ``t[i]`` and store the interpolated value.

        ``val`` is the loop carry
        ``(t, cur_y, cur_f, cur_t, dt, last_t, interp_coeff, solution)``.
        The inner while-loop keeps applying ``_while_body_fun(func, state)``
        until the current time reaches ``t[i]``; the interpolation polynomial
        is then evaluated at the relative position of ``t[i]`` within the last
        step and written into row ``i`` of ``solution``.

        Returns the updated carry in the same layout as ``val``.
        """
        (t, cur_y, cur_f, cur_t, dt,
         last_t, interp_coeff, solution) = val

        def _before_target(state):
            # state[2] is the integrator's current time.
            return state[2] < t[i]

        advance = functools.partial(_while_body_fun, func)
        cur_y, cur_f, cur_t, dt, last_t, interp_coeff = jax.lax.while_loop(
            _before_target, advance,
            (cur_y, cur_f, cur_t, dt, last_t, interp_coeff))

        step_fraction = (t[i] - last_t) / (cur_t - last_t)
        y_at_output = np.polyval(interp_coeff, step_fraction)

        # NOTE(review): `jax.ops.index_update` no longer exists in recent JAX
        # releases; `solution.at[i, :].set(...)` is the documented successor.
        # Confirm the pinned JAX version before migrating.
        new_solution = jax.ops.index_update(
            solution, jax.ops.index[i, :], y_at_output)
        return (t, cur_y, cur_f, cur_t, dt, last_t, interp_coeff,
                new_solution)
Example #5
0
def j0(x):
    """Bessel function of the first kind of order zero.

    The argument is folded onto its absolute value, then one of three
    expressions is selected element-wise:

    * ``|x| < 1e-5``: ``1 - x**2 / 4``;
    * ``|x| <= 5``: ``(z - DR1) * (z - DR2) * RP(z) / RQ(z)`` with ``z = x**2``;
    * ``|x| > 5``: an amplitude/phase form built from the ``PP/PQ`` and
      ``QP/QQ`` rational functions of ``(5 / x)**2``.

    Args:
        x: array of evaluation points.

    Returns:
        Array of J0 values with the same shape as ``x``.
    """
    abs_x = jnp.where(x > 0., x, -x)

    # Small-argument branches (|x| <= 5).
    x_sq = abs_x * abs_x
    quadratic = 1. - x_sq / 4.
    rational = (x_sq - DR1) * (x_sq - DR2)
    rational = rational * jnp.polyval(RP, x_sq) / jnp.polyval(RQ, x_sq)
    small_branch = jnp.where(abs_x < 1e-5, quadratic, rational)

    # Large-argument branch (|x| > 5).  The reciprocal is masked to zero on
    # the small-argument side (and offset by 1e-10) so that automatic
    # differentiation does not fail when x contains 0.
    inv_x = jnp.where(abs_x <= 5., 0., 1. / (abs_x + 1e-10))
    w = 5.0 * inv_x
    w_sq = w * w
    p_ratio = jnp.polyval(PP, w_sq) / jnp.polyval(PQ, w_sq)
    q_ratio = jnp.polyval(QP, w_sq) / jnp.polyval(QQ, w_sq)
    shifted = abs_x - PIO4
    combined = p_ratio * jnp.cos(shifted) - w * q_ratio * jnp.sin(shifted)
    large_branch = combined * SQ2OPI * jnp.sqrt(inv_x)

    return jnp.where(abs_x <= 5., small_branch, large_branch)
Example #6
0
 def _stats(self, b, moments='mv'):
     mu, mu2, g1, g2 = None, None, None, None
     if 'm' in moments:
         mask = b > 1
         bt = np.extract(mask, b)
         mu = np.where(mask, bt / (bt - 1.0), np.inf)
     if 'v' in moments:
         mask = b > 2
         bt = np.extract(mask, b)
         mu2 = np.where(mask, bt / (bt - 2.0) / (bt - 1.0) ** 2, np.inf)
     if 's' in moments:
         mask = b > 3
         bt = np.extract(mask, b)
         vals = 2 * (bt + 1.0) * np.sqrt(bt - 2.0) / ((bt - 3.0) * np.sqrt(bt))
         g1 = np.where(mask, vals, np.nan)
     if 'k' in moments:
         mask = b > 4
         bt = np.extract(mask, b)
         vals = (6.0 * np.polyval([1.0, 1.0, -6, -2], bt)
                 / np.polyval([1.0, -7.0, 12.0, 0.0], bt))
         g2 = np.where(mask, vals, np.nan)
     return mu, mu2, g1, g2
Example #7
0
def _expint1(x):
    # 0 < x <= 2
    A = [
        -5.350447357812542947283e0,
        2.185049168816613393830e2,
        -4.176572384826693777058e3,
        5.541176756393557601232e4,
        -3.313381331178144034309e5,
        1.592627163384945414220e6,
    ]
    B = [
        1.0,
        -5.250547959112862969197e1,
        1.259616186786790571525e3,
        -1.756549581973534652631e4,
        1.493062117002725991967e5,
        -7.294949239640527645655e5,
        1.592627163384945429726e6,
    ]
    A, B = [jnp.array(U, dtype=x.dtype) for U in [A, B]]
    f = jnp.polyval(A, x) / jnp.polyval(B, x)
    return x * f + jnp.euler_gamma + jnp.log(x)
Example #8
0
def product_log(w):
    """Series approximation to the Lambert W function between -1/e and 0.

    Two truncated series are evaluated and selected element-wise:

    * for ``w > -0.5 / e``: a degree-4 polynomial in ``w``;
    * otherwise: a degree-10 polynomial in ``sqrt(1/e + w)``.

    Args:
        w: array of evaluation points.

    Returns:
        Array of approximate W(w) values with the same shape as ``w``.
    """
    e = jnp.exp(1.)
    root_two = jnp.sqrt(2)

    # Coefficients are listed lowest order first and reversed for polyval.
    series_in_w = jnp.array([0., 1., -1., 3. / 2., -8. / 3.])
    series_in_root = jnp.array([
        -1.,
        jnp.sqrt(2 * e),
        -(2 * e) / 3.,
        (11 * e**1.5) / (18. * root_two),
        -(43 * e**2) / 135.,
        +(769 * e**2.5) / (2160. * root_two),
        -(1768 * e**3) / 8505.,
        (680863 * e**3.5) / (2.7216e6 * root_two),
        -(3926 * e**4) / 25515.,
        (226287557 * e**4.5) / (1.1757312e9 * root_two),
        -(23105476 * e**5) / 1.89448875e8,
    ])

    near_zero = jnp.polyval(series_in_w[::-1], w)
    near_branch = jnp.polyval(series_in_root[::-1],
                              jnp.sqrt(jnp.exp(-1.) + w))
    return jnp.where(w > -0.5 / e, near_zero, near_branch)
def cumulative_tomographic_weight_dimensionless_polynomial(
        Q, gamma_prime, n, w1, w2):
    """
    Evaluate the polynomial model of the dimensionless cumulative weight.

    Per the original description this models

        log P(|x1-x2 + t1*p1 - t2*p2|^2 < lambda)

    which is invariant to a common scaling of the vectors and of lambda, so
    it is expressed in the dimensionless form

        P(|n + t1*w1 - t2*w2|^2 < lambda')

    with

        n = x1-x2 / |x1-x2|   (unit vector)
        w1 = p1 / |x1-x2|
        w2 = p2 / |x1-x2|
        lambda' = lambda / |x1-x2|^2

    The seven features ``[1, w1.w1, w2.w2, -2 w1.w2, 2 n.w1, -2 n.w2,
    1 - gamma_prime]`` are mapped through ``Q`` (reshaped to ``(-1, 7)``)
    to polynomial coefficients, and that polynomial is evaluated at
    ``gamma_prime``.

    Args:
        Q: coefficient table; must be reshapeable to ``(-1, 7)``.
        gamma_prime: evaluation point of the polynomial (also enters the
            last feature as ``1 - gamma_prime``).
        n: direction vector.
        w1: first scaled vector.
        w2: second scaled vector, or ``None`` to reuse ``w1``.

    Returns:
        ``jnp.polyval(Q.reshape((-1, 7)) @ features, gamma_prime)``.
    """
    if w2 is None:
        w2 = w1

    w1_sq = w1 @ w1
    w2_sq = w2 @ w2
    cross = (-2. * w1) @ w2
    n_dot_w1 = (2. * n) @ w1
    n_dot_w2 = (-2. * n) @ w2
    offset = 1. - gamma_prime

    features = jnp.asarray(
        [1., w1_sq, w2_sq, cross, n_dot_w1, n_dot_w2, offset])
    poly_coeffs = Q.reshape((-1, 7)) @ features
    return jnp.polyval(poly_coeffs, gamma_prime)
Example #10
0
def calc_tb_params(dr, cutoff, kwargs):
    """Build the distance-dependent Slater-Koster parameter array.

    For every key in the fixed list of Slater-Koster channels, the polynomial
    with coefficients ``kwargs[key]`` (stored lowest order first, hence the
    reversal before ``jnp.polyval``) is evaluated at ``dr * 1.88973``.
    Entries with ``dr <= 0.1`` (on-site terms) or ``dr > cutoff`` are zero.

    Args:
        dr: array of inter-particle distances, of any rank.
        cutoff: float cutoff distance in Angstrom.
        kwargs: dict mapping each Slater-Koster key to its polynomial
            coefficients.

    Returns:
        Array of shape ``dr.shape + (10,)`` whose last axis follows the
        order of ``sk_key_list``.

    Raises:
        KeyError: if a required Slater-Koster key is missing from ``kwargs``.
    """
    sk_key_list = ['Vsss', 'Vsps', 'Vpps', 'Vppp', 'Vsds', 'Vpds', 'Vpdp', 'Vdds', 'Vddp', 'Vddd']
    # One trailing channel per key, so the size always tracks the key list.
    param = jnp.zeros(tuple(dr.shape) + (len(sk_key_list),))

    # Both the mask and the scaled distance are the same for every channel.
    excluded = jnp.logical_or(dr <= 0.1, dr > cutoff)
    # NOTE(review): 1.88973 is presumably the Angstrom -> Bohr factor; confirm.
    scaled_dr = dr * 1.88973

    for channel, key in enumerate(sk_key_list):
        values = jnp.polyval(kwargs[key][-1::-1], scaled_dr)
        # Ellipsis indexing works for any rank of `dr`.  The previous
        # hard-coded `[:, :, :, counter]` only worked for 3-D input and
        # failed on the 2-D distance matrix described in the docstring.
        param = param.at[..., channel].set(jnp.where(excluded, 0.0, values))
    return param  # interactions[species_a, species_b]
Example #11
0
  def scan_fun(carry, target_t):
    """Advance the solver state to ``target_t`` and interpolate the solution there.

    ``carry`` is the list ``[y, f, t, dt, last_t, interp_coeff]``.  Steps are
    attempted until the current time reaches ``target_t``, the attempt
    counter reaches ``mxstep``, or the step size is no longer positive.

    Returns ``(carry, y_target)`` where ``y_target`` is the interpolation
    polynomial evaluated at the relative position of ``target_t`` inside the
    last step.
    """

    def cond_fun(state):
      attempts, _, _, now, step, _, _ = state
      return (now < target_t) & (attempts < mxstep) & (step > 0)

    def body_fun(state):
      i, y, f, t, dt, last_t, interp_coeff = state
      trial_y, trial_f, trial_err, k = runge_kutta_step(func_, y, f, t, dt)
      trial_t = t + dt
      ratio = mean_error_ratio(trial_err, rtol, atol, y, trial_y)
      trial_coeff = interp_fit_dopri(y, trial_y, k, dt)
      # The step size is re-tuned whether or not the trial is accepted.
      dt = optimal_step_size(dt, ratio)

      accepted = [i + 1, trial_y, trial_f, trial_t, dt, t, trial_coeff]
      rejected = [i + 1, y, f, t, dt, last_t, interp_coeff]
      # Element-wise select: keep the trial values only when ratio <= 1.
      keep_trial = ratio <= 1.
      return map(partial(jnp.where, keep_trial), accepted, rejected)

    # The leading 0 is the attempt counter; it is dropped again afterwards.
    _, *carry = lax.while_loop(cond_fun, body_fun, [0] + carry)
    _, _, t, _, last_t, interp_coeff = carry
    step_fraction = (target_t - last_t) / (t - last_t)
    y_target = jnp.polyval(interp_coeff,
                           step_fraction.astype(interp_coeff.dtype))
    return carry, y_target
Example #12
0
def polyval(p, x):
  """Evaluate ``jnp.polyval(p, x)`` and wrap the result in a ``JaxArray``.

  Both arguments are first passed through ``_remove_jaxarray``.
  """
  raw_p, raw_x = _remove_jaxarray(p), _remove_jaxarray(x)
  result = jnp.polyval(raw_p, raw_x)
  return JaxArray(result)
Example #13
0
# NumPy-backed stand-ins for `tf.math` functions.  Each lambda mirrors the
# TensorFlow call signature, including the trailing `name` argument, which is
# accepted but never used.  The lambda is passed to `utils.copy_docstring`
# together with the matching `tf.math` function (presumably so the wrapper
# carries the TF docstring -- confirm in `utils`).

# Element-wise negation via np.negative.
negative = utils.copy_docstring(tf.math.negative,
                                lambda x, name=None: np.negative(x))

# NOTE(review): disabled.  As written, the lambda below returns the
# `np.nextafter` function itself instead of calling it with (x1, x2); fix
# that before re-enabling.
# nextafter = utils.copy_docstring(
#     tf.math.nextafter,
#     lambda x1, x2, name=None: np.nextafter)

# Element-wise inequality via np.not_equal.
not_equal = utils.copy_docstring(tf.math.not_equal,
                                 lambda x, y, name=None: np.not_equal(x, y))

# Delegates to scipy_special.polygamma(a, x).
polygamma = utils.copy_docstring(
    tf.math.polygamma, lambda a, x, name=None: scipy_special.polygamma(a, x))

# Polynomial evaluation via np.polyval(coeffs, x).
polyval = utils.copy_docstring(
    tf.math.polyval, lambda coeffs, x, name=None: np.polyval(coeffs, x))

# Element-wise power via np.power; intentionally shadows the builtin `pow`.
pow = utils.copy_docstring(  # pylint: disable=redefined-builtin
    tf.math.pow,
    lambda x, y, name=None: np.power(x, y))

# Real part via np.real.
real = utils.copy_docstring(tf.math.real,
                            lambda input, name=None: np.real(input))

# Element-wise reciprocal via np.reciprocal.
reciprocal = utils.copy_docstring(tf.math.reciprocal,
                                  lambda x, name=None: np.reciprocal(x))

# Logical AND reduction via np.all; `axis` is passed through `_astuple`
# (defined elsewhere) before being handed to NumPy.
reduce_all = utils.copy_docstring(
    tf.math.reduce_all,
    lambda input_tensor, axis=None, keepdims=False, name=None: (  # pylint: disable=g-long-lambda
        np.all(input_tensor, _astuple(axis), keepdims=keepdims)))
Example #14
0
def polyfitval(x, y, deg):
    """Fit a degree-``deg`` polynomial along each trailing-axis slice and evaluate it.

    ``x`` and ``y`` are mapped over their last axis; for every slice the
    coefficients returned by ``polyfit(x_slice, y_slice, deg)`` are evaluated
    back at ``x_slice``.  The per-slice results are stacked along the last
    axis of the output.
    """
    def _fit_and_evaluate(x_slice, y_slice):
        coefficients = polyfit(x_slice, y_slice, deg)
        return jnp.polyval(coefficients, x_slice)

    batched = vmap(_fit_and_evaluate, in_axes=-1, out_axes=-1)
    return batched(x, y)
Example #15
0
def response_poly(theta: np.ndarray, x: np.ndarray) -> np.ndarray:
    """Evaluate the polynomial response at ``x``.

    Args:
        theta: polynomial coefficients, highest power first (the ordering
            ``np.polyval`` expects).
        x: points at which to evaluate the polynomial.

    Returns:
        ``theta[0] * x**(N-1) + ... + theta[N-1]`` evaluated element-wise.
    """
    response = np.polyval(theta, x)
    return response
Example #16
0
def dist_dependent_params(dr, kwargs, key):
    """Evaluate the polynomial stored under ``kwargs[key]`` at the distances ``dr``.

    The coefficients are used in the order ``jnp.polyval`` expects (highest
    power first) and ``dr`` is used as given, without rescaling.
    """
    coefficients = kwargs[key]
    return jnp.polyval(coefficients, dr)