def log_pred_density(model, samples, *args, **kwargs):
    """Score posterior ``samples`` of ``model`` with a WAIC-style criterion.

    The log-likelihood of the ``'y'`` site is taken from ``log_likelihood``.
    Its trailing axis is summed out, and axis 0 is treated as the posterior
    draw axis. Shape is presumably (num_draws, ..., last) — TODO confirm
    against callers.

    Args:
        model: model function passed straight to ``log_likelihood``.
        samples: posterior samples passed straight to ``log_likelihood``.
        *args, **kwargs: forwarded unchanged to ``log_likelihood``.

    Returns:
        A pair ``(lppd - p_waic, log_lk)``:

        * ``lppd`` is the log of the mean (over axis 0) likelihood of the
          summed log-likelihood.
        * ``p_waic`` is its sample variance over axis 0 (``ddof=1``).
        * ``log_lk`` is the raw, unsummed ``'y'`` log-likelihood.
    """
    site_log_lik = log_likelihood(model, samples, *args, **kwargs)['y']

    # Joint log-likelihood over the trailing axis, one entry per draw/group.
    per_draw = site_log_lik.sum(-1)
    num_draws = per_draw.shape[0]

    # log( (1/S) * sum_s exp(ll_s) ), computed stably via logsumexp.
    lppd = nn.logsumexp(per_draw, 0) - jnp.log(num_draws)

    # Effective-parameter penalty: unbiased variance across draws.
    penalty = jnp.var(per_draw, axis=0, ddof=1)

    return lppd - penalty, site_log_lik
def predict(model, at_bats, hits, z, rng_key, player_names, train=True):
    """Print posterior-predictive results for ``model``.

    When ``train`` is false, also print the log pointwise predictive density.

    Args:
        model: model function; its ``__name__`` labels the printed banner.
        at_bats: passed as the model input to ``Predictive`` and to
            ``log_likelihood``.
        hits: observed values handed to ``print_results`` and to
            ``log_likelihood``.
        z: posterior samples (``posterior_samples`` for ``Predictive``).
        rng_key: PRNG key used to draw predictive samples.
        player_names: labels forwarded to ``print_results``.
        train: selects the " - TRAIN" or " - TEST" banner suffix. Only when
            false is the predictive density computed and printed.

    Returns:
        None. All output goes through ``print_results`` and ``print``.
    """
    suffix = ' - TRAIN' if train else ' - TEST'
    banner = '=' * 30 + model.__name__ + suffix + '=' * 30

    sampler = Predictive(model, posterior_samples=z)
    predicted_obs = sampler(rng_key, at_bats)['obs']
    print_results(banner, predicted_obs, player_names, at_bats, hits)

    if train:
        return

    pointwise_ll = log_likelihood(model, z, at_bats, hits)['obs']
    num_draws = jnp.shape(pointwise_ll)[0]
    # Per data point: log of the Monte Carlo mean likelihood over draws (axis 0).
    lpd_per_point = logsumexp(pointwise_ll, axis=0) - jnp.log(num_draws)
    # Total over all test points.
    print('\nLog pointwise predictive density: {:.2f}\n'.format(lpd_per_point.sum()))
def _compute_log_likelihood_null(posterior_samples, data):
    """Return the ``"obs"``-site log-likelihood of ``model_null``.

    Args:
        posterior_samples: posterior draws evaluated by ``log_likelihood``.
        data: mapping unpacked as keyword arguments into the model call.

    Returns:
        The ``"obs"`` entry of the dict produced by ``log_likelihood``.
    """
    per_site = log_likelihood(model_null, posterior_samples, **data)
    return per_site["obs"]