def _atan2_taylor(primals_in, series_in):
    """Taylor-mode (jet) rule for ``lax.atan2``.

    The primal output is ``lax.atan2(x, y)``. The higher-order terms are
    taken from ``atan(x / y)``: atan2 and atan(x / y) differ only by a
    piecewise-constant multiple of pi, so their derivatives agree.

    The steps are:
      1. Get the series of the quotient ``u = x / y`` with ``jet(lax.div, ...)``.
      2. Get the series of the derivative ``c = 1 / (1 + u**2)`` with a second
         ``jet`` call.
      3. Build the output series from ``v' = c * u'`` with the same recurrence
         used by the other jet rules (``fact`` / ``_scale``).

    Args:
      primals_in: pair ``(x, y)`` of primal inputs.
      series_in: Taylor series terms for ``x`` and ``y``, in the format
        accepted by ``jet``.

    Returns:
      ``(primal_out, series_out)``, where ``primal_out = lax.atan2(x, y)`` and
      ``series_out`` is a list with one output term per input series term.
    """
    x, y = primals_in
    primal_out = lax.atan2(x, y)

    # Series of the quotient u = x / y.
    quot, quot_series = jet(lax.div, primals_in, series_in)

    # lax binary ops require identical dtypes, so the constant 1 must carry
    # the quotient's dtype rather than being passed as the Python int 1.
    one = lax.full_like(quot, 1)
    c0, cs = jet(lambda t: lax.div(one, 1 + lax.square(t)),
                 (quot,), (quot_series,))

    c = [c0] + cs
    u = [quot] + quot_series
    v = [primal_out] + [None] * len(quot_series)
    for k in range(1, len(v)):
        v[k] = fact(k - 1) * sum(
            _scale(k, j) * c[k - j] * u[j] for j in range(1, k + 1))
    primal_out, *series_out = v
    return primal_out, series_out
def sample(self, key, sample_shape=()):
    """Draw ``(phi, psi)`` angle pairs from the distribution.

    The sampling proceeds in three steps:
      1. Draw ``phi`` from its marginal via ``Sine._phi_marginal``. That helper
         is defined elsewhere and is presumably the iterative scheme from the
         reference below.
      2. Draw ``psi`` from the conditional ``VonMises(beta, alpha)`` given
         ``phi``.
      3. Shift both angles by their location parameters and wrap them into
         ``[-pi, pi)``.

    ** References: **
        1. A New Unified Approach for the Simulation of a Wide Class of Directional Distributions
           John T. Kent, Asaad M. Ganeiber & Kanti V. Mardia (2018)

    :param key: PRNG key; split into one key for ``phi`` and one for ``psi``.
    :param sample_shape: leading shape of the returned samples.
    :return: array reshaped to ``sample_shape + batch_shape + event_shape``.
    """
    key_phi, key_psi = random.split(key)

    rho = self.correlation
    kappas = jnp.stack((self.phi_concentration, self.psi_concentration))
    kappa_psi = kappas[1]

    # Eigenvalue pair (0, lam). Where lam is negative, both entries are shifted
    # up by -lam so the smaller one becomes zero. The shift itself is handed
    # to the phi sampler.
    lam = 0.5 * (kappas[0] - rho**2 / kappa_psi)
    eigvals = jnp.stack((jnp.zeros_like(lam), lam))
    shift = jnp.where(
        eigvals[1] < 0,
        eigvals[1],
        jnp.zeros_like(eigvals[1], dtype=eigvals.dtype),
    )
    eigvals = eigvals - shift
    b0 = self._bfind(eigvals)

    n_draws = _numel(sample_shape)
    log_den = log_I1(0, kappa_psi).squeeze(0)
    marginal_shape = (n_draws, 2, _numel(self.batch_shape))
    state = Sine._phi_marginal(
        marginal_shape, key_phi, kappas, rho, eigvals, b0, shift, log_den
    )
    # NOTE(review): `state.done` is not inspected. If the phi sampler stops at
    # its iteration limit (`SineBivariateVonMises.max_sample_iter`), the result
    # is returned without an error. Presumably this is because a Python-level
    # `jnp.all(...)` check cannot run under jit — confirm.

    phi_xy = state.phi
    phi = lax.atan2(phi_xy[:, 1:], phi_xy[:, :1])

    sin_phi = jnp.sin(phi)
    alpha = jnp.sqrt(kappa_psi**2 + (rho * sin_phi) ** 2)
    beta = lax.atan(rho / kappa_psi * sin_phi)
    psi = VonMises(beta, alpha).sample(key_psi)

    def wrap(angle, loc):
        # Shift by `loc`, then wrap into [-pi, pi).
        return (angle + loc + pi) % (2 * pi) - pi

    pairs = jnp.concatenate(
        (wrap(phi, self.phi_loc), wrap(psi, self.psi_loc)), axis=1
    )
    pairs = jnp.transpose(pairs, (0, 2, 1))
    return pairs.reshape(*sample_shape, *self.batch_shape, *self.event_shape)
def angle(x, deg=False):
    """Return the counterclockwise angle of ``x`` from the positive real axis.

    Follows ``numpy.angle``:
      - The angle is ``atan2(imag(x), real(x))``.
      - Real inputs are treated as having a zero imaginary part, so negative
        reals map to pi rather than 0.
      - Boolean and integer inputs are promoted to floating point, because
        atan2 is only defined for floating-point operands.

    Args:
      x: complex or real array-like input.
      deg: if True, return the angle in degrees instead of radians.

    Returns:
      Array of angles with the same shape as ``x``.
    """
    import numpy as np

    if iscomplexobj(x):
        re, im = lax.real(x), lax.imag(x)
    else:
        re = x
        im = zeros_like(x)
        if np.issubdtype(im.dtype, np.integer) or np.issubdtype(im.dtype, np.bool_):
            # Promote non-float inputs; lax.atan2 needs floating operands of
            # identical dtype.
            im = lax.convert_element_type(im, float)
            re = lax.convert_element_type(re, float)
    result = lax.atan2(im, re)
    if deg:
        result = result * (180 / np.pi)
    return result