コード例 #1
0
def interp(x, xp, fp):
    """
    Simple equivalent of np.interp that computes a piecewise-linear interpolation.

    For each query point the bracketing interval ``[xp[i-1], xp[i]]`` is
    located with ``searchsorted`` and ``fp`` is linearly interpolated across it.

    Unlike ``np.interp`` no checks are performed and values are not clamped:
    query points below ``xp[0]`` / above ``xp[-1]`` are linearly extrapolated
    from the first / last interval, so keep query points inside the array
    range if that is not what you want.

    Args:
      x: query point(s); a scalar or an array.
      xp: 1d array of sample x-coordinates, sorted in increasing order,
        with at least two entries.
      fp: 1d array of sample values, same length as ``xp``.

    Returns:
      The interpolated value(s), with the shape of ``x``.
    """
    # Index of the right end of the bracketing interval. Clipping to
    # [1, len(xp) - 1] keeps both ``ind - 1`` and ``ind`` valid *integer*
    # indices (array indexing rejects float indices), and makes
    # out-of-range queries use the first/last interval.
    ind = np.clip(np.searchsorted(xp, x, side="right"), 1, len(xp) - 1)

    x0, x1 = xp[ind - 1], xp[ind]
    y0, y1 = fp[ind - 1], fp[ind]

    # Linear interpolation across [x0, x1].
    slope = (y1 - y0) / (x1 - x0)
    return y0 + slope * (x - x0)
コード例 #2
0
def get_boundaries_intersections(z: jnp.ndarray, d: jnp.ndarray,
                                 trust_radius: Union[float, jnp.ndarray]):
    """
    Find where the line ``z + t * d`` crosses the trust-region boundary.

    Ported from SciPy. Solves the scalar quadratic equation
    ``||z + t d|| == trust_radius`` for ``t`` (a line-sphere intersection)
    and returns both roots, ordered low to high.

    Args:
      z: current point.
      d: direction vector.
      trust_radius: radius of the trust region.

    Returns:
      Tuple ``(t_low, t_high)`` with the two roots, smaller one first.
    """
    # Coefficients of the quadratic a*t**2 + b*t + c = 0.
    a = _dot(d, d)
    b = 2 * _dot(z, d)
    c = _dot(z, z) - trust_radius**2
    root_disc = jnp.sqrt(b * b - 4 * a * c)

    # Mathematically the same as the textbook roots
    #   (-b - root_disc) / (2*a)  and  (-b + root_disc) / (2*a),
    # but adding root_disc with the sign of b avoids cancellation; the
    # second root then follows from the product of roots, t1 * t2 = c / a.
    # See Matrix Computation p.97 for a better justification.
    q = b + jnp.copysign(root_disc, b)
    t1 = -q / (2 * a)
    t2 = -2 * c / q

    # Branch-free form of `(t1, t2) if t1 < t2 else (t2, t1)`.
    ordered = t1 < t2
    return (jnp.where(ordered, t1, t2), jnp.where(ordered, t2, t1))
コード例 #3
0
def copysign(x1, x2):
  """Return the magnitude of ``x1`` with the sign of ``x2``, element-wise.

  ``JaxArray`` arguments are unwrapped to their ``.value`` before calling
  ``jnp.copysign``; the result is wrapped in a new ``JaxArray``.
  """
  magnitude = x1.value if isinstance(x1, JaxArray) else x1
  sign_source = x2.value if isinstance(x2, JaxArray) else x2
  return JaxArray(jnp.copysign(magnitude, sign_source))