def interp(x, xp, fp):
    """
    Piecewise-linear interpolation of the samples (xp, fp) at the query x.

    Simplified stand-in for ``np.interp``. For every query point the
    nearest sample is located, the neighbouring sample on the same side as
    the query is picked, and the straight line through those two samples is
    evaluated at the query.

    Parameters
    ----------
    x : scalar or array
        Query point(s). A scalar gives a scalar result, an array gives a
        result of the same shape.
    xp : 1d array
        Sample abscissae. Assumed to be sorted and to hold at least three
        points (the index clipping below relies on it) -- this is not
        checked.
    fp : 1d array
        Sample values, same length as ``xp``.

    Returns
    -------
    Interpolated value(s) ``a * x + b`` on the selected segment.

    Notes
    -----
    No range checks are performed: queries outside ``[xp[0], xp[-1]]`` are
    linearly extrapolated from the first / last segment rather than clamped
    as ``np.interp`` would do.
    """
    x = np.asarray(x)

    # Nearest neighbour of every query point. The trailing axis added to x
    # broadcasts against xp, so scalar and array queries share one code path.
    ind = np.argmin((x[..., None] - xp) ** 2, axis=-1)

    # Keep the pivot strictly inside the table so that both ind - 1 and
    # ind + 1 are valid indices.
    ind = np.clip(ind, 1, len(xp) - 2)
    xi = xp[ind]

    # Integer offset towards the neighbour on the query's side of the pivot.
    # It must be an integer (a float offset cannot be used as an index) and
    # it must be derived from the unclipped query, otherwise points in the
    # first interval would be evaluated on the second segment.
    step = np.where(x < xi, -1, 1)
    neighbour = ind + step

    # Line through the pivot and the selected neighbour.
    a = (fp[neighbour] - fp[ind]) / (xp[neighbour] - xi)
    b = fp[ind] - a * xi
    return a * x + b
def get_boundaries_intersections(z: jnp.ndarray, d: jnp.ndarray, trust_radius: Union[float, jnp.ndarray]):
    """
    Intersect the line ``z + t * d`` with the sphere of radius ``trust_radius``.

    Ported from scipy. Solves the scalar quadratic ``||z + t d|| ==
    trust_radius`` for ``t`` and returns both roots as a tuple
    ``(t_low, t_high)``, ordered from low to high.
    """
    # Coefficients of the quadratic  quad * t**2 + lin * t + const == 0.
    quad = _dot(d, d)
    lin = 2 * _dot(z, d)
    const = _dot(z, z) - trust_radius**2
    root_disc = jnp.sqrt(lin * lin - 4 * quad * const)

    # The textbook roots are
    #     (-lin - root_disc) / (2 * quad)  and  (-lin + root_disc) / (2 * quad).
    # Giving root_disc the sign of lin before adding avoids subtracting two
    # nearly equal numbers, so the round-off error is smaller; the second
    # root is then computed from the same sum (see Matrix Computations,
    # p.97).
    stable_sum = lin + jnp.copysign(root_disc, lin)
    first_root = -stable_sum / (2 * quad)
    second_root = -2 * const / stable_sum

    # Order the pair without Python-level branching:
    # (first, second) if first < second else (second, first).
    already_sorted = first_root < second_root
    t_low = jnp.where(already_sorted, first_root, second_root)
    t_high = jnp.where(already_sorted, second_root, first_root)
    return (t_low, t_high)
def copysign(x1, x2):
    """
    Apply ``jnp.copysign`` to the operands and wrap the result in a JaxArray.

    Any operand that is a ``JaxArray`` is replaced by its ``.value`` before
    the call; other operands are passed through unchanged.
    """
    magnitude = x1.value if isinstance(x1, JaxArray) else x1
    sign_source = x2.value if isinstance(x2, JaxArray) else x2
    return JaxArray(jnp.copysign(magnitude, sign_source))