Ejemplo n.º 1
0
def _atan2_taylor(primals_in, series_in):
  """Taylor-mode (``jet``) propagation rule for ``lax.atan2``.

  Up to a piecewise-constant offset, ``atan2(x, y)`` equals ``atan(q)`` with
  ``q = x / y``, so its derivatives follow ``d atan(q) = q' / (1 + q**2)``.
  The series of ``q`` and of the factor ``1 / (1 + q**2)`` are obtained via
  ``jet``. They are then combined term by term with the ``fact`` / ``_scale``
  weights.

  Args:
    primals_in: pair ``(x, y)`` of primal values.
    series_in: Taylor coefficients of ``x`` and ``y``, as passed to ``jet``.

  Returns:
    ``(primal_out, series_out)``. ``primal_out`` is ``atan2(x, y)``.
    ``series_out`` is a list with one entry per term of the quotient's series.
  """
  numer, denom = primals_in
  out_primal = lax.atan2(numer, denom)

  # NOTE(review): this goes through x / y, so the propagated series is
  # non-finite wherever y == 0, even though atan2 itself is smooth there
  # (x != 0).
  quot, quot_series = jet(lax.div, primals_in, series_in)
  # Series of the derivative factor 1 / (1 + q**2).
  deriv0, deriv_series = jet(lambda q: lax.div(1, 1 + lax.square(q)),
                             (quot,), (quot_series,))
  deriv = [deriv0] + deriv_series
  quot_terms = [quot] + quot_series

  # Each output term uses only the (already known) series of q and of the
  # derivative factor, so the terms can be built independently of each other.
  out_series = [
      fact(k - 1) * sum(_scale(k, j) * deriv[k - j] * quot_terms[j]
                        for j in range(1, k + 1))
      for k in range(1, len(quot_series) + 1)
  ]
  return out_primal, out_series
Ejemplo n.º 2
0
    def sample(self, key, sample_shape=()):
        """Draw ``(phi, psi)`` angle pairs from the distribution.

        ``phi`` is drawn through ``Sine._phi_marginal`` (presumably a
        rejection-style sampler for the marginal of ``phi``; verify against
        its definition). ``psi`` is then drawn from ``VonMises(beta, alpha)``,
        whose parameters depend on ``sin(phi)``. Both angles are shifted by
        their location parameters and wrapped into ``[-pi, pi)``.

        :param key: PRNG key; split into one key for ``phi`` and one for ``psi``.
        :param tuple sample_shape: leading shape of the returned samples.
        :return: array reshaped to ``sample_shape + batch_shape + event_shape``.

        ** References: **
            1. A New Unified Approach for the Simulation of a Wide Class of Directional Distributions
               John T. Kent, Asaad M. Ganeiber & Kanti V. Mardia (2018)
        """
        phi_key, psi_key = random.split(key)

        corr = self.correlation
        conc = jnp.stack((self.phi_concentration, self.psi_concentration))

        # Eigenvalue pair (0, eig). eigmin = min(0, eig) is subtracted so both
        # become non-negative; the shift is passed on to the phi sampler.
        eig = 0.5 * (conc[0] - corr**2 / conc[1])
        eig = jnp.stack((jnp.zeros_like(eig), eig))
        eigmin = jnp.where(eig[1] < 0, eig[1], jnp.zeros_like(eig[1]))
        eig = eig - eigmin
        b0 = self._bfind(eig)

        # The phi sampler works on a flattened layout: all sample dims are
        # collapsed into `total` and all batch dims into one trailing axis.
        total = _numel(sample_shape)
        phi_den = log_I1(0, conc[1]).squeeze(0)
        phi_shape = (total, 2, _numel(self.batch_shape))
        # TODO(review): conc/corr/eig/b0/eigmin/phi_den (and the locations
        # below) keep the full `batch_shape`, while `phi_shape` flattens it
        # into a single axis. The broadcasts here look like they only line up
        # for a batch_shape of rank <= 1. Confirm, or flatten the parameters.
        phi_state = Sine._phi_marginal(phi_shape, phi_key, conc, corr, eig, b0,
                                       eigmin, phi_den)

        # NOTE: `phi_state.done` is not checked, so exhausting the sampler's
        # iteration budget (cf. `SineBivariateVonMises.max_sample_iter`) goes
        # unreported. A Python-level `if not jnp.all(phi_state.done)` check
        # cannot run on traced values under `jit`, which is presumably why it
        # was disabled.

        # Axis 1 of `phi_state.phi` is read as (x, y) = (index 0, index 1).
        phi = lax.atan2(phi_state.phi[:, 1:], phi_state.phi[:, :1])

        # psi ~ VonMises(beta, alpha), with parameters depending on sin(phi).
        sin_phi = jnp.sin(phi)
        alpha = jnp.sqrt(conc[1]**2 + (corr * sin_phi)**2)
        beta = lax.atan(corr / conc[1] * sin_phi)

        psi = VonMises(beta, alpha).sample(psi_key)

        # Shift by the locations and wrap both angles into [-pi, pi).
        phi_psi = jnp.concatenate(
            (
                (phi + self.phi_loc + pi) % (2 * pi) - pi,
                (psi + self.psi_loc + pi) % (2 * pi) - pi,
            ),
            axis=1,
        )
        # Presumably (total, 2, batch) -> (total, batch, 2), before restoring
        # the original sample/batch/event dimensions.
        phi_psi = jnp.transpose(phi_psi, (0, 2, 1))
        return phi_psi.reshape(*sample_shape, *self.batch_shape,
                               *self.event_shape)
Ejemplo n.º 3
0
def angle(x, deg=False):
    """Return the counterclockwise angle of each element of *x*.

    Follows ``numpy.angle``: the result is ``atan2(imag(x), real(x))``. For
    real input the imaginary part is zero. The angle is therefore ``pi`` for
    negative entries (including ``-0.0``) and ``0`` for non-negative ones. It
    is not zero everywhere. Non-floating real input (ints, bools) is promoted
    to the default floating dtype first, because ``atan2`` needs a floating
    type.

    Args:
        x: array-like input, real or complex.
        deg: if true, return degrees instead of radians. Defaults to
            ``False``, as in ``numpy.angle``.

    Returns:
        Array of angles, in radians unless ``deg`` is true.
    """
    # Local import: only needed for array conversion and dtype inspection.
    import jax.numpy as jnp

    x = jnp.asarray(x)
    if iscomplexobj(x):
        re, im = lax.real(x), lax.imag(x)
    else:
        # Returning zeros here would be wrong for negative entries, since
        # numpy.angle(-1.0) == pi. Use atan2(0, x) in a floating dtype.
        re = x if jnp.issubdtype(x.dtype, jnp.inexact) else x.astype(float)
        im = zeros_like(re)
    result = lax.atan2(im, re)
    if deg:
        result = result * (180.0 / jnp.pi)
    return result