Пример #1
0
def loss_MAP(mu,
             tau_unc,
             D,
             i0,
             i1,
             mu0,
             beta=1.0,
             gamma_shape=1.0,
             gamma_rate=1.0,
             alpha=1.0):
    """Return the summed pairwise log-likelihood plus log-priors on mu and tau.

    NOTE(review): despite the ``loss_`` prefix the value is NOT negated; it
    is the sum of the three log terms. Callers presumably negate it before
    minimising -- confirm.

    Args:
        mu: Point locations, indexed by ``i0`` / ``i1``. The prior uses a
            2x2 covariance, so each point is expected to have 2 coordinates.
        tau_unc: Unconstrained per-point parameter; passed through a scaled
            softplus and shifted by ``EPSILON`` to obtain ``tau``.
        D: Observed value for each pair. NOTE(review): it must broadcast
            against the pairwise distance, which keeps a trailing axis of
            size 1 -- confirm the caller supplies a compatible shape.
        i0: Indices of the first point of every pair.
        i1: Indices of the second point of every pair.
        mu0: Mean of the normal prior on ``mu``.
        beta: Scalar multiplying the identity covariance of that prior.
        gamma_shape: Shape parameter of the gamma prior on ``tau``.
        gamma_rate: Rate parameter of the gamma prior (used as 1 / scale).
        alpha: Unused by this function.

    Returns:
        Scalar: sum(log-likelihood) + sum(log p(mu)) + sum(log p(tau)).
    """
    # Transform the unconstrained parameter before it is used anywhere.
    tau = EPSILON + jax.nn.softplus(SCALE * tau_unc)

    # Per-pair combination of the two endpoint values: a * b / (a + b).
    tau_a, tau_b = tau[i0], tau[i1]
    tau_sum = tau_a + tau_b
    pair_tau = tau_a * tau_b / tau_sum
    log_pair_tau = jnp.log(tau_a) + jnp.log(tau_b) - jnp.log(tau_sum)

    # Euclidean distance between the endpoints of every pair.
    dist = jnp.linalg.norm(mu[i0] - mu[i1], ord=2, axis=1, keepdims=True)

    # i0e is the exponentially scaled Bessel function, which is why the
    # quadratic term is written with (D - dist)**2.
    quadratic = 0.5 * pair_tau * (D - dist)**2
    log_bessel = jnp.log(i0e(pair_tau * D * dist))
    pair_log_llh = jnp.log(D) + log_pair_tau - quadratic + log_bessel

    # Priors are evaluated on every point, not only on those in a pair.
    prior_mu = multivariate_normal.logpdf(mu, mean=mu0, cov=beta * jnp.eye(2))
    prior_tau = gamma.logpdf(tau, a=gamma_shape, scale=1.0 / gamma_rate)

    return jnp.sum(pair_log_llh) + jnp.sum(prior_mu) + jnp.sum(prior_tau)
Пример #2
0
def _ncx2_log_pdf(x, df, nc):
    """Log-density of a non-central chi-square, hard-wired to the df == 2 form.

    The expression follows SciPy's ncx2 log-pdf
    (https://github.com/scipy/scipy/blob/v1.5.2/scipy/stats/_distn_infrastructure.py#L556):

        df2 = df / 2 - 1
        xs, ns = sqrt(x), sqrt(nc)
        res = xlogy(df2 / 2, x / nc) - 0.5 * (xs - ns)**2 + log(ive(df2, xs * ns) / 2)

    For df == 2, df2 is 0, so the xlogy term vanishes and ive(0, .) is i0e.
    The identity (xs**2 + ns**2) / 2 = (xs - ns)**2 / 2 + xs * ns moves the
    exp(-xs * ns) factor into the exponentially scaled Bessel function,
    which improves numerical stability for large xs * ns (the same trick
    SciPy documents for ``rice.pdf``).

    Args:
        x: Point at which the density is evaluated.
        df: Degrees of freedom. Currently ignored -- only the df == 2 branch
            is implemented.
        nc: Non-centrality parameter.

    Returns:
        The log-density reshaped to a 0-d array, so the inputs must yield
        exactly one element.
    """
    # EPSILON is added under both roots before the square root is taken.
    sqrt_x = jnp.sqrt(x + EPSILON)
    sqrt_nc = jnp.sqrt(nc + EPSILON)

    half_sq_gap = 0.5 * (sqrt_x - sqrt_nc)**2
    log_scaled_bessel = jnp.log(i0e(sqrt_x * sqrt_nc))

    log_pdf = -jnp.log(2.0) - half_sq_gap
    log_pdf = log_pdf + log_scaled_bessel
    return log_pdf.reshape(())
Пример #3
0
def loss_one_pair(mu_i, mu_j, s_i, s_j, D, n_components):
    """Return the negative log-likelihood of one observed value ``D``.

    The density has the form
        D / s * exp(-(D**2 + d**2) / (2 * s)) * I0(D * d / s)
    with ``s = s_i + s_j`` and ``d = ||mu_i - mu_j||``.

    The Bessel factor is evaluated with the exponentially scaled ``i0e``,
    where i0e(z) = exp(-|z|) * I0(z). The exponent therefore has to absorb
    the missing exp(+D * d / s), which turns (D**2 + d**2) into (D - d)**2.
    Keeping the unscaled (D**2 + d**2) exponent together with i0e would
    silently drop the D * d / s term from the log-likelihood.

    NOTE(review): the rewrite assumes D * d / s >= 0 -- confirm that D and
    the combined scale are positive at every call site.

    Args:
        mu_i: Location of the first point.
        mu_j: Location of the second point.
        s_i: Scale term of the first point.
        s_j: Scale term of the second point.
        D: Observed value for the pair (its log is taken).
        n_components: Unused by this function.

    Returns:
        The negated log-likelihood.
    """
    # EPSILON offsets both quantities that end up in a denominator or log.
    s_ij = s_i + s_j + EPSILON
    d_ij = jnp.linalg.norm(mu_i - mu_j) + EPSILON

    log_llh = (jnp.log(D) - jnp.log(s_ij) - 0.5 * (D - d_ij)**2 / s_ij +
               jnp.log(i0e(d_ij * D / s_ij)))
    return -log_llh
Пример #4
0
def log_likelihood_one_pair(mu_i, mu_j, tau_i, tau_j, D):
    """Return the log-likelihood of one observed value ``D`` for a pair.

    Args:
        mu_i: Location of the first point.
        mu_j: Location of the second point.
        tau_i: Per-point parameter of the first point.
        tau_j: Per-point parameter of the second point.
        D: Observed value for the pair (its log is taken).

    Returns:
        log(D) + log(t) - 0.5 * t * (D - d)**2 + log(i0e(t * D * d)), where
        t = tau_i * tau_j / (tau_i + tau_j) and d = ||mu_i - mu_j||.
    """
    pair_tau = tau_i * tau_j / (tau_i + tau_j)
    dist = jnp.linalg.norm(mu_i - mu_j)

    # i0e is the exponentially scaled Bessel function, hence the
    # (D - dist)**2 form of the quadratic term.
    mismatch = 0.5 * pair_tau * (D - dist)**2
    log_bessel = jnp.log(i0e(pair_tau * D * dist))
    return jnp.log(D) + jnp.log(pair_tau) - mismatch + log_bessel
Пример #5
0
def loss_MAP(params, D, i0, i1, mu0, sigma0, sigma_local, alpha):
    """Return the loss minimised to obtain the MAP estimate of ``mu``.

    The loss is ``-0.5 * mean(log-likelihood) - mean(log-prior)``; it is
    meant to be minimised (e.g. with Adam), which maximises the posterior.

    Args:
        params: Sequence whose first element is ``mu``, the point locations.
        D: Observed value for each pair. NOTE(review): it must broadcast
            against the pairwise distance, which keeps a trailing axis of
            size 1 -- confirm the caller supplies a compatible shape.
        i0: Indices of the first point of every pair.
        i1: Indices of the second point of every pair.
        mu0: Forwarded to ``log_prior_mu_batch``.
        sigma0: Forwarded to ``log_prior_mu_batch``.
        sigma_local: Per-point scale terms; the two endpoint values are
            summed to form the per-pair scale.
        alpha: Unused by this function.

    Returns:
        Scalar loss value.
    """
    mu = params[0]
    mu_i, mu_j = mu[i0], mu[i1]
    sigma_ij = sigma_local[i0] + sigma_local[i1]
    d = jnp.linalg.norm(mu_i - mu_j, ord=2, axis=1, keepdims=True)

    # The Bessel argument is D * d / sigma_ij. The parentheses matter:
    # without them only EPSILON is divided by sigma_ij and the argument is
    # never scaled, which is inconsistent with the quadratic term above it.
    log_llh = (jnp.log(D) - jnp.log(sigma_ij) - 0.5 *
               (D - d + EPSILON)**2 / sigma_ij +
               jnp.log(i0e((D * d + EPSILON) / sigma_ij)))
    log_mu_all = log_prior_mu_batch(mu, mu0, sigma0)

    # Negated because the optimiser minimises; the likelihood term carries
    # a fixed weight of 0.5 relative to the prior term.
    return -0.5 * jnp.mean(log_llh) - jnp.mean(log_mu_all)
Пример #6
0
 def variance(self):
     """Return the circular variance, broadcast to ``self.batch_shape``.

     Computed as 1 - i1e(k) / i0e(k) with k = ``self.concentration``; both
     Bessel functions are the exponentially scaled variants.
     """
     kappa = self.concentration
     bessel_ratio = i1e(kappa) / i0e(kappa)
     return jnp.broadcast_to(1.0 - bessel_ratio, self.batch_shape)
Пример #7
0
 def log_prob(self, value):
     """Return the log-density at ``value``.

     Reads ``self.loc`` and ``self.concentration``. The angular offset is
     wrapped modulo 2*pi before the cosine is taken. Because i0e is the
     exponentially scaled Bessel function, the cosine term is written as
     ``concentration * (cos(.) - 1)``.
     """
     kappa = self.concentration
     log_normalizer = jnp.log(2 * jnp.pi) + jnp.log(i0e(kappa))
     wrapped_offset = (value - self.loc) % (2 * jnp.pi)
     return -log_normalizer + kappa * (jnp.cos(wrapped_offset) - 1)
Пример #8
0
def _ncx2_log_pdf(x, df, nc):
    """Return -log(2) - 0.5 * (sqrt(x) - sqrt(nc))**2 + log(i0e(sqrt(x) * sqrt(nc))).

    Args:
        x: Point at which the expression is evaluated.
        df: Unused by this function. NOTE(review): the formula looks like
            the 2-degrees-of-freedom non-central chi-square log-density --
            confirm that callers only ever pass df == 2.
        nc: Non-centrality parameter.

    Returns:
        The value reshaped to a 0-d array, so the inputs must yield exactly
        one element.
    """
    # EPSILON is added under both roots before the square root is taken.
    root_x = jnp.sqrt(x + EPSILON)
    root_nc = jnp.sqrt(nc + EPSILON)

    gap_term = 0.5 * (root_x - root_nc)**2
    bessel_term = jnp.log(i0e(root_x * root_nc))

    log_pdf = -jnp.log(2.0) - gap_term
    log_pdf = log_pdf + bessel_term
    return log_pdf.reshape(())