Ejemplo n.º 1
0
def norm_logbicop_diag_approx(log_u, rho):
    """Approximate log C(u, u) of a bivariate copula with parameter ``rho``.

    Presumably the Gaussian copula evaluated on the diagonal (inferred from the
    name -- confirm). Values with u > 0.5 are reflected to 1 - u, the lower-half
    approximation ``log u + log_g_cop(u, rho) + log(interp(u, rho))`` is
    evaluated there, and the result is mapped back using
    ``C(u, u) = 2u - 1 + C(1 - u, 1 - u)``.

    Args:
        log_u: log of the marginal value u; clipped to
            ``[log(1e-6), log(1 - 1e-6)]`` before use.
        rho: copula parameter; ``jnp.arcsin(rho)`` is taken, so |rho| <= 1 is
            needed for a finite result.

    Returns:
        The approximation of log C(u, u), broadcast over the inputs.
    """
    eps = 1e-6
    log_u = jnp.clip(log_u, jnp.log(eps), jnp.log(1 - eps))
    # Boolean mask + positional jnp.where: x/y are positional-only in current
    # JAX, so the former `jnp.where(cond, x=1, y=0)` raises a TypeError.
    is_lower = log_u <= jnp.log(0.5)  # u <= 0.5
    # Reflect the upper half: log(u) -> log(1 - u), computed stably via log1p.
    log_u = jnp.where(is_lower, log_u, jnp.log1p(-jnp.exp(log_u)))

    u = jnp.exp(log_u)  # after the reflection u <= 0.5 everywhere
    log_g = log_g_cop(u, rho)  # helper defined elsewhere; fed u <= 0.5 only
    log_interp = jnp.log((1 + (rho / 2) + (1 / jnp.pi) * jnp.arcsin(rho))
                         + u * ((2 / jnp.pi) * jnp.arcsin(rho) - rho))
    logbicop = log_u + log_g + log_interp

    # Upper half: add back 2*u_orig - 1, which equals 1 - 2u for the reflected u.
    # The lower half keeps logbicop as-is (no exp/log round trip that could
    # underflow for very negative log values).
    logbicop = jnp.where(is_lower, logbicop,
                         jnp.log(jnp.exp(logbicop) + (1 - 2 * u)))
    return logbicop
Ejemplo n.º 2
0
def _arcsin(x, do_backprop):
    """Inverse sine with the argument saturated to [-1, 1].

    Args:
        x: input array.
        do_backprop: when true, saturate with an ``abs``/``where`` mask
            (see the linked JAX issue) instead of ``clip``.

    Returns:
        ``arcsin`` of the saturated input.
    """
    if not do_backprop:
        return np.arcsin(np.clip(x, -1, 1))

    # https://github.com/google/jax/issues/654
    saturated = np.abs(x) >= 1
    return np.arcsin(np.where(saturated, np.sign(x), x))
Ejemplo n.º 3
0
    def nngp_fn(cov12, var1, var2):
      """Transform a covariance for the nonlinearity selected by ``name``.

      ``name`` comes from the enclosing scope and is matched by substring in
      the order Identity, Erf, Sin, Relu; the first match wins.

      Args:
        cov12: covariance between the two inputs.
        var1: variance of the first input.
        var2: variance of the second input.

      Returns:
        The transformed (NNGP) covariance.

      Raises:
        NotImplementedError: if ``name`` matches none of the supported kinds.
      """
      if 'Identity' in name:
        return cov12

      if 'Erf' in name:
        denom = np.sqrt((1 + 2 * var1) * (1 + 2 * var2))
        return np.arcsin(2 * cov12 / denom) * 2 / np.pi

      if 'Sin' in name:
        shift = -0.5 * (var1 + var2)
        plus = np.exp(shift + cov12)
        minus = np.exp(shift - cov12)
        return (plus - minus) / 2

      if 'Relu' in name:
        var_prod = var1 * var2
        # Floor at 1e-30 so round-off never feeds a negative value to sqrt.
        root = np.sqrt(np.maximum(var_prod - cov12 ** 2, 1e-30))
        theta = np.arctan2(root, cov12)
        weight = (1 - theta / np.pi) / 2
        return root / (2 * np.pi) + weight * cov12

      raise NotImplementedError(name)
Ejemplo n.º 4
0
def get_roll_pitch_jax(y_in):
    """Compute roll and pitch (radians) from the configuration ``y_in``.

    ``get_sin_cos_jax`` (defined elsewhere) is unpacked into seven sines
    followed by seven cosines; presumably one pair per joint -- confirm.
    The entries r_31, r_32, r_33 are combined with the ZYX extraction
    ``roll = atan2(r_32, r_33)``, ``pitch = -asin(r_31)``.

    Args:
        y_in: input passed straight to ``get_sin_cos_jax``.

    Returns:
        Tuple ``(roll, pitch)``.
    """
    sq1, sq2, sq3, sq4, sq5, sq6, sq7, cq1, cq2, cq3, cq4, cq5, cq6, cq7 = get_sin_cos_jax(y_in)
    r_32 = -cq7*(cq5*sq2*sq3 - sq5*(cq2*sq4 - cq3*cq4*sq2)) - sq7*(cq6*(cq5*(cq2*sq4 - cq3*cq4*sq2) + sq2*sq3*sq5) + sq6*(cq2*cq4 + cq3*sq2*sq4))
    r_33 = -cq6*(cq2*cq4 + cq3*sq2*sq4) + sq6*(cq5*(cq2*sq4 - cq3*cq4*sq2) + sq2*sq3*sq5)
    r_31 = cq7*(cq6*(cq5*(cq2*sq4 - cq3*cq4*sq2) + sq2*sq3*sq5) + sq6*(cq2*cq4 + cq3*sq2*sq4)) - sq7*(cq5*sq2*sq3 - sq5*(cq2*sq4 - cq3*cq4*sq2))

    # r_31 appears to be a rotation-matrix entry (so ideally in [-1, 1]), but
    # the long product chain can overshoot by round-off; clip so arcsin never
    # returns NaN near +-90 degrees of pitch.
    return jnp.arctan2(r_32, r_33), -jnp.arcsin(jnp.clip(r_31, -1.0, 1.0))
Ejemplo n.º 5
0
def _transform_kernels_erf(kernels, do_backprop):
    """Compute new kernels after an `Erf` layer.

    Args:
        kernels: unpacked as ``(var1, nngp, var2, ntk, _, is_height_width)``;
            ``var2`` and ``ntk`` may be ``None``.
        do_backprop: forwarded to ``_arcsin`` for the NNGP entry.

    Returns:
        A ``Kernel`` holding the transformed ``var1``, ``nngp``, ``var2`` and
        ``ntk``, with its fifth field set to ``False``.
    """
    var1, nngp, var2, ntk, _, is_height_width = kernels

    def scaled_arcsin(value, denom):
        return np.arcsin(2 * value / denom) * 2 / np.pi

    denom1 = 1 + 2 * var1
    denom2 = 1 + 2 * var2 if var2 is not None else None
    prod = _get_var_prod(denom1, nngp, denom2)

    # Equals (2 / pi) * d/dx arcsin(2x / sqrt(prod)) at x = nngp (prod held fixed).
    dot_sigma = 4 / (np.pi * np.sqrt(prod - 4 * nngp**2))
    if ntk is not None:
        ntk *= dot_sigma

    nngp = _arcsin(2 * nngp / np.sqrt(prod), do_backprop) * 2 / np.pi
    var1 = scaled_arcsin(var1, denom1)
    var2 = None if var2 is None else scaled_arcsin(var2, denom2)

    return Kernel(var1, nngp, var2, ntk, False, is_height_width)
Ejemplo n.º 6
0
    def compute_pitch_radians(self) -> jnp.ndarray:
        """Compute pitch angle. Uses the ZYX mobile robot convention.

        Reads the quaternion components from ``self.wxyz`` in (w, x, y, z)
        order.

        Returns:
            Euler angle in radians, in [-pi/2, pi/2].
        """
        # https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Quaternion_to_Euler_angles_conversion
        q0, q1, q2, q3 = self.wxyz
        # For a unit quaternion this sine lies in [-1, 1], but round-off or a
        # slightly non-normalized quaternion can push it past +-1 near gimbal
        # lock, where arcsin would return NaN; clamp it.
        sin_pitch = jnp.clip(2 * (q0 * q2 - q3 * q1), -1.0, 1.0)
        return jnp.arcsin(sin_pitch)
Ejemplo n.º 7
0
def forward_pass(x_trj, u_trj, k_trj, K_trj):
    """Roll out new state/control trajectories with ``forward_pass_looper``.

    Assumes ``np`` is ``jax.numpy`` here (the body uses ``.at[]`` updates and
    ``lax.fori_loop``). ``TIME_STEPS`` and ``forward_pass_looper`` are defined
    elsewhere.

    Args:
        x_trj: current state trajectory; its first entry seeds the rollout.
        u_trj: current control trajectory.
        k_trj: passed through to ``forward_pass_looper`` (presumably
            feed-forward gains -- confirm).
        K_trj: passed through to ``forward_pass_looper`` (presumably
            feedback gains -- confirm).

    Returns:
        Tuple ``(x_trj_new, u_trj_new)`` produced by the loop.
    """
    # arcsin(sin(u)) folds every control value into [-pi/2, pi/2].
    u_trj = np.arcsin(np.sin(u_trj))

    x_trj_new = np.empty_like(x_trj)
    # The new rollout starts from the same initial state. `jax.ops.index_update`
    # was removed from JAX; `.at[...].set(...)` is the functional replacement.
    x_trj_new = x_trj_new.at[0].set(x_trj[0])
    u_trj_new = np.empty_like(u_trj)

    *_, x_trj_new, u_trj_new = lax.fori_loop(
        0, TIME_STEPS - 1, forward_pass_looper,
        [x_trj, u_trj, k_trj, K_trj, x_trj_new, u_trj_new],
    )

    return x_trj_new, u_trj_new
Ejemplo n.º 8
0
def xyz2equirect(xyz):
    """
    Map unit direction vectors to equirectangular image coordinates;
    the inverse of equirect2xyz.

    Latitude is asin(y) (y clipped to [-1, 1]) and longitude is atan2(x, z);
    both are scaled so the output lies in [-1.0, 1.0].

    Args:
        xyz: jnp.ndarray [..., 3] unit vectors
    Returns:
        uv: jnp.ndarray [..., 2] image-space coordinates (x, y) in [-1.0, 1.0]
    """
    px, py, pz = xyz[..., 0], xyz[..., 1], xyz[..., 2]
    latitude = jnp.arcsin(jnp.clip(py, -1.0, 1.0))
    longitude = jnp.arctan2(px, pz)
    return jnp.stack([longitude / jnp.pi, 2.0 * latitude / jnp.pi], axis=-1)
Ejemplo n.º 9
0
def arcsin(x):
  """Element-wise inverse sine, returned wrapped in a ``JaxArray``.

  A ``JaxArray`` input is unwrapped via its ``.value`` first; any other input
  is passed to ``jnp.arcsin`` unchanged.
  """
  raw = x.value if isinstance(x, JaxArray) else x
  return JaxArray(jnp.arcsin(raw))
Ejemplo n.º 10
0
def arcsin(a: Numeric):
    """Return the element-wise inverse sine of ``a`` via ``jnp.arcsin``."""
    result = jnp.arcsin(a)
    return result
Ejemplo n.º 11
0
class JaxBox(qml.math.TensorBox):
    """Implements the :class:`~.TensorBox` API for JAX arrays (``jax.numpy``).

    Tensors passed to the constructor are converted with ``jnp.asarray``.
    For more details, please refer to the :class:`~.TensorBox` documentation.
    """

    # Each lambda below operates on the raw JAX array in ``self.data``;
    # ``wrap_output`` (defined elsewhere) presumably re-wraps the result in a
    # box -- confirm.
    abs = wrap_output(lambda self: jnp.abs(self.data))
    angle = wrap_output(lambda self: jnp.angle(self.data))
    arcsin = wrap_output(lambda self: jnp.arcsin(self.data))
    cast = wrap_output(lambda self, dtype: jnp.array(self.data, dtype=dtype))
    expand_dims = wrap_output(
        lambda self, axis: jnp.expand_dims(self.data, axis=axis))
    ones_like = wrap_output(lambda self: jnp.ones_like(self.data))
    sqrt = wrap_output(lambda self: jnp.sqrt(self.data))
    sum = wrap_output(lambda self, axis=None, keepdims=False: jnp.sum(
        self.data, axis=axis, keepdims=keepdims))
    T = wrap_output(lambda self: self.data.T)
    # mode="wrap": out-of-range indices wrap around instead of being clamped.
    take = wrap_output(lambda self, indices, axis=None: jnp.take(
        self.data, indices, axis=axis, mode="wrap"))

    def __init__(self, tensor):
        """Convert ``tensor`` to a JAX array and hand it to the TensorBox base."""
        tensor = jnp.asarray(tensor)

        super().__init__(tensor)

    @staticmethod
    def astensor(tensor):
        """Return ``tensor`` converted to a JAX array."""
        return jnp.asarray(tensor)

    @staticmethod
    @wrap_output
    def concatenate(values, axis=0):
        """Concatenate ``values`` (boxes or tensors) along ``axis``."""
        return jnp.concatenate(JaxBox.unbox_list(values), axis=axis)

    @staticmethod
    @wrap_output
    def dot(x, y):
        """Dot product of ``x`` and ``y`` (boxes or tensors).

        Two 0-d operands are multiplied, two 2-D operands use ``@``, and all
        other ranks fall back to ``jnp.dot``.
        """
        x, y = JaxBox.unbox_list([x, y])
        x = jnp.asarray(x)
        y = jnp.asarray(y)

        if x.ndim == 0 and y.ndim == 0:
            return x * y

        if x.ndim == 2 and y.ndim == 2:
            return x @ y

        return jnp.dot(x, y)

    @property
    def interface(self):
        """Name of the wrapped framework: ``"jax"``."""
        return "jax"

    def numpy(self):
        """Return the stored data.

        NOTE: this returns ``self.data`` as-is (a JAX array), not a converted
        ``numpy.ndarray``.
        """
        return self.data

    @property
    def requires_grad(self):
        """Always ``True`` for this box."""
        return True

    @property
    def shape(self):
        """Shape of the stored data."""
        return self.data.shape

    @staticmethod
    @wrap_output
    def stack(values, axis=0):
        """Stack ``values`` (boxes or tensors) along a new ``axis``."""
        return jnp.stack(JaxBox.unbox_list(values), axis=axis)

    @staticmethod
    @wrap_output
    def where(condition, x, y):
        """Element-wise select from ``x`` where ``condition`` holds, else ``y``."""
        return jnp.where(condition, *JaxBox.unbox_list([x, y]))
Ejemplo n.º 12
0
# NumPy-backed stand-ins for `tf.math` ops. Each lambda is passed to
# `utils.copy_docstring` together with the matching TF op (presumably so the
# TF docstring is attached to the NumPy implementation -- confirm in `utils`).
# Parameter names (`input`, `x`, `name`) mirror the TF signatures; `name` is
# accepted but unused.
angle = utils.copy_docstring(tf.math.angle,
                             lambda input, name=None: np.angle(input))

# `axis=None` is mapped to axis 0; otherwise the axis goes through `_astuple`.
# The result is cast to the NumPy dtype corresponding to `output_type`.
argmax = utils.copy_docstring(
    tf.math.argmax,
    lambda input, axis=None, output_type=tf.int64, name=None: (  # pylint: disable=g-long-lambda
        np.argmax(input, axis=0 if axis is None else _astuple(axis)).astype(
            utils.numpy_dtype(output_type))))

# Same axis / dtype handling as `argmax` above.
argmin = utils.copy_docstring(
    tf.math.argmin,
    lambda input, axis=None, output_type=tf.int64, name=None: (  # pylint: disable=g-long-lambda
        np.argmin(input, axis=0 if axis is None else _astuple(axis)).astype(
            utils.numpy_dtype(output_type))))

asin = utils.copy_docstring(tf.math.asin, lambda x, name=None: np.arcsin(x))

asinh = utils.copy_docstring(tf.math.asinh, lambda x, name=None: np.arcsinh(x))

atan = utils.copy_docstring(tf.math.atan, lambda x, name=None: np.arctan(x))

# Argument order (y, x) matches both tf.math.atan2 and np.arctan2.
atan2 = utils.copy_docstring(tf.math.atan2,
                             lambda y, x, name=None: np.arctan2(y, x))

atanh = utils.copy_docstring(tf.math.atanh, lambda x, name=None: np.arctanh(x))

# `scipy_special` is presumably `scipy.special` (imported elsewhere) -- confirm.
bessel_i0 = utils.copy_docstring(tf.math.bessel_i0,
                                 lambda x, name=None: scipy_special.i0(x))

bessel_i0e = utils.copy_docstring(tf.math.bessel_i0e,
                                  lambda x, name=None: scipy_special.i0e(x))
Ejemplo n.º 13
0
def to_inf(param, bounds):
    """Rescale ``param`` from ``[a, b]`` to ``[-1, 1]`` and apply ``arcsin``.

    ``a`` maps to ``-pi/2`` and ``b`` maps to ``pi/2``.

    Args:
        param: scalar or array expected to lie in ``[a, b]``.
        bounds: pair ``(a, b)`` with ``a != b``.

    Returns:
        ``arcsin`` of the rescaled parameter.
    """
    a, b = bounds
    # Affine map [a, b] -> [-1, 1]. The previous `(2.0 * param - a) / (b - a)`
    # was only correct for a == 0; any other lower bound shifted the interval
    # and made arcsin return NaN as param approached b.
    x = 2.0 * (param - a) / (b - a) - 1.0
    return jnp.arcsin(x)
Ejemplo n.º 14
0
def to_inf_vec(param, bounds):
    """Element-wise rescale of ``param`` to ``[-1, 1]`` followed by ``arcsin``.

    Element ``i`` of ``param`` is mapped from ``[bounds[i, 0], bounds[i, 1]]``
    to ``[-1, 1]``, so the lower bound gives ``-pi/2`` and the upper ``pi/2``.

    Args:
        param: array of values, one per row of ``bounds``.
        bounds: array-like indexed as ``bounds[:, 0]`` (lower bounds ``a``) and
            ``bounds[:, 1]`` (upper bounds ``b``), with ``a != b``.

    Returns:
        ``arcsin`` of the rescaled values.
    """
    bounds = jnp.asarray(bounds)
    a, b = bounds[:, 0], bounds[:, 1]
    # Affine map [a, b] -> [-1, 1]; the previous `(2.0 * param - a) / (b - a)`
    # was only correct where a == 0.
    x = 2.0 * (param - a) / (b - a) - 1.0
    return jnp.arcsin(x)
Ejemplo n.º 15
0
def sin_bijector_inv(x, a, b):
    """Solve ``x = a * sin(y) + b`` for ``y`` on the principal branch of arcsin."""
    scaled = (x - b) / a
    return jnp.arcsin(scaled)