def ll(sigma, theta, y):
    """Return the negative log-likelihood of a zero-location skew-normal.

    Each observation contributes
    ``log(2) + log phi(y; 0, sigma) + log Phi(alpha . y)``, where
    ``alpha = theta / sqrt(diag(sigma))``.

    Args:
        sigma: Square covariance matrix handed to the Gaussian density.
        theta: Skewness vector, one entry per dimension.
        y: Observations whose last axis has length ``p`` (presumably
            shaped ``(n, p)`` or ``(p,)`` -- confirm against callers).

    Returns:
        Scalar negative log-likelihood summed over all observations.
    """
    p = y.shape[-1]
    marginal_scale = jnp.sqrt(jnp.diag(sigma))
    # Elementwise rescaling of the skewness by the marginal std deviations.
    alpha = theta / marginal_scale
    log_cdf_terms = norm.logcdf(jnp.matmul(alpha, y.T))
    log_pdf_terms = mvn.logpdf(y, mean=jnp.zeros(p), cov=sigma)
    # The density carries a factor 2, i.e. log(2) per observation (the
    # previous code added a literal 2 once, independent of sample count).
    log_two_total = log_cdf_terms.size * jnp.log(2.0)
    return -(log_two_total + jnp.sum(log_pdf_terms) + jnp.sum(log_cdf_terms))
def logpdf(self, z):
    """Compute the skew-normal log-density at sample ``z``.

    Evaluates ``log(2) + log phi(z - loc; 0, cov) + log Phi(alpha . (z - loc))``
    using the instance attributes ``alpha``, ``loc``, ``k`` and ``cov``.

    Args:
        z: Sample(s) whose last axis matches ``self.k`` (presumably shaped
            ``(k,)`` or ``(n, k)`` -- confirm against callers).

    Returns:
        Log-density value(s), one per sample.
    """
    centered = z - self.loc
    log_cdf_term = norm.logcdf(jnp.matmul(self.alpha, centered.T))
    log_pdf_term = mvn.logpdf(
        centered, mean=jnp.zeros(shape=(self.k,)), cov=self.cov
    )
    # The density has a leading factor 2, so the log-density needs log(2),
    # not the literal 2 that was added before.
    return jnp.log(2.0) + log_pdf_term + log_cdf_term
def ll_chol(pars, y):
    """Return the skew-normal negative log-likelihood, Cholesky-parameterized.

    ``pars`` is split into the leading triangular-factor entries and the
    trailing ``p`` skewness entries. The factor entries fill the upper
    triangle of a ``(p, p)`` matrix in ``jnp.triu_indices`` order; its
    transpose ``L`` gives the covariance ``sigma = L @ L.T``.

    Args:
        pars: Flat parameter vector: ``p * (p + 1) / 2`` factor entries
            followed by ``p`` skewness entries.
        y: Observations whose last axis has length ``p`` (presumably
            shaped ``(n, p)`` -- confirm against callers).

    Returns:
        Scalar negative log-likelihood summed over all observations.
    """
    p = y.shape[-1]
    factor_entries, theta = pars[:-p], pars[-p:]
    # `.at[...].set(...)` replaces `index_update`, which was removed from JAX.
    upper = jnp.zeros(shape=(p, p)).at[jnp.triu_indices(p)].set(factor_entries)
    lower = upper.T
    sigma = jnp.matmul(lower, lower.T)
    marginal_scale = jnp.sqrt(jnp.diag(sigma))
    # Elementwise rescaling of the skewness by the marginal std deviations.
    alpha = theta / marginal_scale
    log_cdf_terms = norm.logcdf(jnp.matmul(alpha, y.T))
    log_pdf_terms = mvn.logpdf(y, mean=jnp.zeros(p), cov=sigma)
    # log(2) per observation comes from the density's leading factor 2
    # (the previous code added a literal 2 once).
    log_two_total = log_cdf_terms.size * jnp.log(2.0)
    return -(log_two_total + jnp.sum(log_pdf_terms) + jnp.sum(log_cdf_terms))
def log_cdf(self, value):
    """Return the normal log-CDF of ``value`` for this object's parameters.

    The location and scale are read from ``self._loc`` and ``self._scale``.
    """
    location, spread = self._loc, self._scale
    return norm.logcdf(value, loc=location, scale=spread)
def norm_logcdf(z):
    """Return the standard-normal log-CDF evaluated at ``z``."""
    log_prob = norm.logcdf(z)
    return log_prob
def probitloss(self, X, y, w):
    """Return the probit negative log-likelihood (NLL) of ``y`` given ``X @ w``.

    Computes ``-sum(y * log Phi(Xw)) - sum((1 - y) * log Phi(-Xw))``.
    """
    linear_pred = jnp.dot(X, w)
    positive_term = jnp.sum(y * jnorm.logcdf(linear_pred))
    negative_term = jnp.sum((1 - y) * jnorm.logcdf(-linear_pred))
    return -positive_term - negative_term
def bernoulli_probit_lik(y, f):
    """Return the Bernoulli log-likelihood of ``y`` under a probit link on ``f``.

    Computes ``y * log Phi(f) + (1 - y) * log Phi(-f)`` elementwise.
    """
    log_prob_one = norm.logcdf(f)
    log_prob_zero = norm.logcdf(-f)
    return y * log_prob_one + (1 - y) * log_prob_zero