def loss_MAP(mu, tau_unc, D, i0, i1, mu0,
             beta=1.0, gamma_shape=1.0, gamma_rate=1.0, alpha=1.0):
    """Return the summed log-likelihood plus log-priors on ``mu`` and ``tau``.

    Three terms are added together: a pairwise log-likelihood of the observed
    values ``D`` given the Euclidean distance between ``mu[i0]`` and
    ``mu[i1]``, a Gaussian log-prior on ``mu`` and a Gamma log-prior on
    ``tau``. Despite the name, the result is not negated here.

    Args:
        mu: Point locations; rows are indexed by ``i0`` / ``i1``. The prior
            uses a ``beta * eye(2)`` covariance, so rows are presumably 2-D
            (confirm with caller).
        tau_unc: Unconstrained parameter mapped to a positive ``tau`` via
            ``EPSILON + softplus(SCALE * tau_unc)``.
        D: Observed pairwise values; presumably shaped to broadcast against
            the ``keepdims`` distances (confirm with caller).
        i0, i1: Index arrays selecting the two endpoints of every pair.
        mu0: Mean of the Gaussian prior on ``mu``.
        beta: Scale of the isotropic prior covariance.
        gamma_shape, gamma_rate: Shape and rate of the Gamma prior on ``tau``.
        alpha: Accepted for interface compatibility; not used in the body.

    Returns:
        Scalar sum of the log-likelihood and both log-priors.
    """
    # Positive parameter obtained from its unconstrained representation.
    tau = EPSILON + jax.nn.softplus(SCALE * tau_unc)

    head_mu, tail_mu = mu[i0], mu[i1]
    head_tau, tail_tau = tau[i0], tau[i1]

    # Combined per-pair term and its logarithm (computed from the logs of
    # the individual factors rather than from the product).
    pair_tau = head_tau * tail_tau / (head_tau + tail_tau)
    log_pair_tau = (jnp.log(head_tau) + jnp.log(tail_tau)
                    - jnp.log(head_tau + tail_tau))

    dist = jnp.linalg.norm(head_mu - tail_mu, ord=2, axis=1, keepdims=True)

    # i0e is the exponentially scaled Bessel function, hence the
    # (D - dist)**2 form of the quadratic term.
    quadratic = 0.5 * pair_tau * (D - dist)**2
    log_bessel = jnp.log(i0e(pair_tau * D * dist))
    pair_log_llh = jnp.log(D) + log_pair_tau - quadratic + log_bessel

    # Priors over all points (not only those that appear in a pair).
    prior_mu = multivariate_normal.logpdf(
        mu, mean=mu0, cov=beta * jnp.eye(2))
    prior_tau = gamma.logpdf(tau, a=gamma_shape, scale=1.0 / gamma_rate)

    return jnp.sum(pair_log_llh) + jnp.sum(prior_mu) + jnp.sum(prior_tau)
def _ncx2_log_pdf(x, df, nc):
    """Return a scalar noncentral chi-square log-density at ``x``.

    The expression evaluated is
    ``-log(2) - (sqrt(x) - sqrt(nc))**2 / 2 + log(i0e(sqrt(x) * sqrt(nc)))``
    (with ``EPSILON`` added under both square roots), which is the
    two-degrees-of-freedom form of the density.

    Args:
        x: Point at which the log-density is evaluated.
        df: Degrees of freedom. Accepted for interface compatibility but not
            used; the formula is fixed to the ``df == 2`` case.
        nc: Non-centrality parameter.

    Returns:
        The log-density reshaped to a 0-d array.
    """
    # Numerical-stability trick, as in SciPy's ``rice.pdf`` / ``ncx2``:
    #   (xs**2 + ns**2) / 2 = (xs - ns)**2 / 2 + xs * ns
    # The exp(-xs * ns) factor is folded into the exponentially scaled
    # Bessel function i0e, which stays finite for large xs * ns where the
    # unscaled I0 would overflow.
    # See https://github.com/scipy/scipy/blob/v1.5.2/scipy/stats/_distn_infrastructure.py#L556
    root_x = jnp.sqrt(x + EPSILON)
    root_nc = jnp.sqrt(nc + EPSILON)

    log_density = -jnp.log(2.0) - 0.5 * (root_x - root_nc)**2
    log_density = log_density + jnp.log(i0e(root_x * root_nc))
    return log_density.reshape(())
def loss_one_pair(mu_i, mu_j, s_i, s_j, D, n_components):
    """Return the negative Rice log-likelihood of one observed value ``D``.

    The log-density is that of a Rice distribution with squared scale
    ``s_ij = s_i + s_j`` and location ``d_ij = ||mu_i - mu_j||``::

        log D - log s_ij - (D**2 + d_ij**2) / (2 s_ij) + log I0(D d_ij / s_ij)

    Because ``i0e(z) = exp(-z) * I0(z)``, using the scaled Bessel function
    requires adding ``z = D * d_ij / s_ij`` back, which turns the quadratic
    term into ``-(D - d_ij)**2 / (2 s_ij)``. The previous version combined
    ``i0e`` with the unscaled ``(D**2 + d_ij**2)`` exponent and therefore
    dropped that term.

    Args:
        mu_i, mu_j: Locations of the two points of the pair.
        s_i, s_j: Per-point terms whose sum is the pair's squared scale.
        D: Observed value for the pair.
        n_components: Accepted for interface compatibility; not used.

    Returns:
        The negated log-likelihood.
    """
    # EPSILON keeps the log and the divisions away from zero.
    s_ij = s_i + s_j + EPSILON
    d_ij = jnp.linalg.norm(mu_i - mu_j) + EPSILON

    log_llh = (jnp.log(D) - jnp.log(s_ij)
               - 0.5 * (D - d_ij)**2 / s_ij
               + jnp.log(i0e(d_ij * D / s_ij)))
    return -log_llh
def log_likelihood_one_pair(mu_i, mu_j, tau_i, tau_j, D):
    """Return the log-likelihood of one observed value ``D`` for a pair.

    The two per-point ``tau`` values are combined as
    ``tau_i * tau_j / (tau_i + tau_j)`` and the location is the Euclidean
    distance between ``mu_i`` and ``mu_j``.

    Args:
        mu_i, mu_j: Locations of the two points of the pair.
        tau_i, tau_j: Per-point ``tau`` values.
        D: Observed value for the pair.

    Returns:
        ``log D + log tau - tau * (D - d)**2 / 2 + log i0e(tau * D * d)``.
    """
    pair_tau = tau_i * tau_j / (tau_i + tau_j)
    dist = jnp.linalg.norm(mu_i - mu_j)

    # i0e is exponentially scaled, which is why the quadratic term is
    # written with (D - dist)**2.
    quadratic = 0.5 * pair_tau * (D - dist)**2
    log_bessel = jnp.log(i0e(pair_tau * D * dist))
    return jnp.log(D) + jnp.log(pair_tau) - quadratic + log_bessel
def loss_MAP(params, D, i0, i1, mu0, sigma0, sigma_local, alpha):
    """Return the loss to minimise: negated log-likelihood and log-prior.

    The pairwise log-likelihood of ``D`` uses ``sigma_local[i0] +
    sigma_local[i1]`` as the pair's squared scale and the Euclidean distance
    between ``mu[i0]`` and ``mu[i1]`` as its location. The prior on ``mu`` is
    delegated to ``log_prior_mu_batch``.

    Args:
        params: Sequence whose first element is the array of locations.
        D: Observed pairwise values; presumably shaped to broadcast against
            the ``keepdims`` distances (confirm with caller).
        i0, i1: Index arrays selecting the two endpoints of every pair.
        mu0, sigma0: Prior parameters forwarded to ``log_prior_mu_batch``.
        sigma_local: Per-point terms; the two endpoints' values are summed.
        alpha: Accepted for interface compatibility; not used in the
            returned expression.

    Returns:
        ``-0.5 * mean(log_llh) - mean(log_prior)``.
    """
    mu = params[0]
    mu_i, mu_j = mu[i0], mu[i1]
    sigma_ij = sigma_local[i0] + sigma_local[i1]
    d = jnp.linalg.norm(mu_i - mu_j, ord=2, axis=1, keepdims=True)

    # The Bessel argument must be the whole product scaled by sigma_ij.
    # Previously only EPSILON was divided (``D * d + EPSILON / sigma_ij``),
    # which left the argument unscaled and inconsistent with the
    # ``(D - d)**2 / sigma_ij`` term that pairs with the scaled i0e.
    log_llh = (jnp.log(D) - jnp.log(sigma_ij)
               - 0.5 * (D - d + EPSILON)**2 / sigma_ij
               + jnp.log(i0e((D * d + EPSILON) / sigma_ij)))

    log_mu_all = log_prior_mu_batch(mu, mu0, sigma0)

    # Negated so that a minimiser (e.g. Adam) maximises the posterior.
    return -0.5 * jnp.mean(log_llh) - jnp.mean(log_mu_all)
def variance(self):
    """Return the circular variance, broadcast to ``self.batch_shape``.

    Computed as ``1 - i1e(concentration) / i0e(concentration)``; the
    exponential scaling of the two Bessel functions cancels in the ratio.
    """
    kappa = self.concentration
    bessel_ratio = i1e(kappa) / i0e(kappa)
    return jnp.broadcast_to(1.0 - bessel_ratio, self.batch_shape)
def log_prob(self, value):
    """Return the log-density at ``value``.

    The angular offset from ``self.loc`` is wrapped into ``[0, 2*pi)``
    before the cosine is taken. The normaliser uses the exponentially scaled
    ``i0e``, which is compensated by the ``cos(...) - 1`` form of the
    exponent.
    """
    kappa = self.concentration
    wrapped = (value - self.loc) % (2 * jnp.pi)
    log_norm = jnp.log(2 * jnp.pi) + jnp.log(i0e(kappa))
    return -log_norm + kappa * (jnp.cos(wrapped) - 1)
def _ncx2_log_pdf(x, df, nc):
    """Return a scalar noncentral chi-square log-density at ``x``.

    Evaluates ``-log(2) - (sqrt(x) - sqrt(nc))**2 / 2 + log(i0e(sqrt(x) *
    sqrt(nc)))`` with ``EPSILON`` added under both square roots. This is the
    two-degrees-of-freedom form of the density; ``df`` is accepted but not
    used.

    Args:
        x: Point at which the log-density is evaluated.
        df: Degrees of freedom (unused).
        nc: Non-centrality parameter.

    Returns:
        The log-density reshaped to a 0-d array.
    """
    sqrt_x = jnp.sqrt(x + EPSILON)
    sqrt_nc = jnp.sqrt(nc + EPSILON)

    # (xs**2 + ns**2) / 2 == (xs - ns)**2 / 2 + xs * ns; the exp(-xs * ns)
    # factor lives inside the exponentially scaled Bessel function i0e.
    gaussian_part = -jnp.log(2.0) - 0.5 * (sqrt_x - sqrt_nc)**2
    total = gaussian_part + jnp.log(i0e(sqrt_x * sqrt_nc))
    return total.reshape(())