def linear_model_ground_truth(model, design, observation_labels, target_labels, eig=True):
    """Closed-form posterior entropy (or entropy reduction) for a linear model.

    For every ``(design.shape[-2], design.shape[-1])`` matrix contained in
    ``design``, the posterior covariance returned by
    ``analytic_posterior_cov`` is restricted to the rows/columns selected by
    ``target_labels`` and the quantity ``0.5 * logdet(2 * pi * e * C)`` is
    evaluated on that sub-matrix ``C``.

    :param model: object exposing ``w_sds`` (a mapping whose values are
        concatenated along the last dimension and squared to form the
        diagonal of the prior covariance) and ``obs_sd``.
    :param design: tensor with at least two dimensions; all leading
        dimensions are treated as batch dimensions.
    :param observation_labels: accepted for interface compatibility; not
        used by this function.
    :param target_labels: a single label or a list of labels; a bare string
        is wrapped into a one-element list.
    :param bool eig: when true, return ``prior_entropy - posterior_entropy``
        where ``prior_entropy`` comes from ``mean_field_entropy``; otherwise
        return the posterior entropy term itself.
    :return: tensor reshaped to ``design.shape[:-2]``.
    """
    if isinstance(target_labels, str):
        target_labels = [target_labels]

    # Diagonal prior covariance built from the concatenated standard deviations.
    weight_sds = torch.cat(list(model.w_sds.values()), dim=-1)
    prior_cov = torch.diag(weight_sds**2)

    # Collapse every leading batch dimension so designs can be visited one by one.
    flat_designs = design.reshape(-1, design.shape[-2], design.shape[-1])
    target_indices = get_indices(target_labels, tensors=model.w_sds)

    entropies = []
    for single_design in torch.unbind(flat_designs):
        full_cov = analytic_posterior_cov(prior_cov, single_design, model.obs_sd)
        # Keep only the block of the covariance belonging to the targets.
        target_cov = full_cov[target_indices, :][:, target_indices]
        entropies.append(0.5 * torch.logdet(2 * math.pi * math.e * target_cov))
    output = torch.tensor(entropies)

    if eig:
        # NOTE(review): the subtraction happens while `output` is still the
        # flattened 1-D vector, i.e. before the final reshape. This presumes
        # the value returned by mean_field_entropy broadcasts against that
        # vector -- confirm for designs with more than one batch dimension.
        prior_entropy = mean_field_entropy(model, [design], whitelist=target_labels)
        output = prior_entropy - output

    return output.reshape(design.shape[:-2])
def true_ape(ns):
    """Analytic APE for each candidate first-group size in ``ns``.

    For every ``n1`` a design is built with
    ``group_assignment_matrix(torch.tensor([n1, N - n1]))``; the posterior
    covariance is obtained from ``analytic_posterior_cov`` using the prior
    covariance ``diag(prior_sds**2)`` and ``torch.tensor(1.)`` as its third
    argument, and ``0.5 * logdet(2 * pi * e * posterior_cov)`` is recorded.

    ``prior_sds``, ``N``, ``group_assignment_matrix`` and
    ``analytic_posterior_cov`` are resolved from the enclosing scope.

    :param ns: iterable of first-group sizes; the second group receives
        ``N - n1``.
    :return: 1-D tensor with one entry per element of ``ns``, in order.
    """
    prior_cov = torch.diag(prior_sds**2)

    # Iterate the sizes directly instead of indexing a pre-built design list;
    # the accumulator no longer reuses the function's own name.
    ape_values = []
    for n1 in ns:
        x = group_assignment_matrix(torch.tensor([n1, N - n1]))
        posterior_cov = analytic_posterior_cov(prior_cov, x, torch.tensor(1.))
        ape_values.append(0.5 * torch.logdet(2 * np.pi * np.e * posterior_cov))

    # torch.tensor (rather than torch.stack) is kept so that an empty `ns`
    # still yields an empty tensor, as before.
    return torch.tensor(ape_values)