def test_get_indices_sizes():
    """Check ``get_indices`` against expected flat indices for named blocks.

    Three labels ("a", "b", "c") are registered in order, each of length 2.
    The lookup is exercised twice: once with the lengths given explicitly
    via ``sizes`` and once with them taken from ``tensors``.  Both modes are
    expected to return the same indices.
    """
    names = ("a", "b", "c")
    # (requested labels, expected flat indices) shared by both lookup modes.
    cases = (
        (["b"], [2, 3]),
        (["b", "c"], [2, 3, 4, 5]),
    )

    sizes = OrderedDict((name, 2) for name in names)
    for labels, expected in cases:
        assert_equal(get_indices(labels, sizes=sizes), torch.tensor(expected))

    tensors = OrderedDict((name, torch.ones(2)) for name in names)
    for labels, expected in cases:
        assert_equal(get_indices(labels, tensors=tensors), torch.tensor(expected))
def linear_model_ground_truth(model, design, observation_labels, target_labels, eig=True):
    """Return an analytic reference value for every design matrix in ``design``.

    ``design`` is flattened to a stack of matrices over its last two
    dimensions.  For each matrix the posterior covariance is obtained from
    ``analytic_posterior_cov``, restricted to the rows and columns belonging
    to ``target_labels``, and converted with the Gaussian entropy expression
    ``0.5 * logdet(2 * pi * e * C)``.

    Args:
        model: Object exposing ``w_sds`` (an ordered mapping of tensors whose
            values are concatenated along the last dimension) and ``obs_sd``.
        design: Tensor with at least two dimensions; leading dimensions are
            treated as batch dimensions.
        observation_labels: Accepted for interface compatibility; not read
            by this function.
        target_labels: A single label or a list of labels selecting entries
            of ``model.w_sds``.
        eig: When truthy, return ``mean_field_entropy(...)`` minus the
            posterior entropy; otherwise return the posterior entropy alone.
            (Presumably "expected information gain" -- confirm with callers.)

    Returns:
        Tensor reshaped to ``design.shape[:-2]``.
    """
    if isinstance(target_labels, str):
        target_labels = [target_labels]

    # Diagonal prior covariance built from the concatenated standard deviations.
    prior_sds = torch.cat(list(model.w_sds.values()), dim=-1)
    prior_cov = torch.diag(prior_sds**2)

    # Collapse all leading batch dimensions so each item is one design matrix.
    n_rows, n_cols = design.shape[-2], design.shape[-1]
    flat_designs = design.reshape(-1, n_rows, n_cols)

    posterior_covs = []
    for single_design in torch.unbind(flat_designs):
        posterior_covs.append(
            analytic_posterior_cov(prior_cov, single_design, model.obs_sd)
        )

    target_indices = get_indices(target_labels, tensors=model.w_sds)

    gaussian_scale = 2 * math.pi * math.e
    entropies = []
    for cov in posterior_covs:
        # Keep only the target rows, then only the target columns.
        target_rows = cov[target_indices, :]
        target_cov = target_rows[:, target_indices]
        entropies.append(0.5 * torch.logdet(gaussian_scale * target_cov))
    output = torch.tensor(entropies)

    batch_shape = design.shape[:-2]
    if not eig:
        return output.reshape(batch_shape)

    prior_entropy = mean_field_entropy(model, [design], whitelist=target_labels)
    return (prior_entropy - output).reshape(batch_shape)
def linear_model_ground_truth(model, design, observation_labels, target_labels, eig=True):
    """Return an analytic reference value for each design along dim 0 of ``design``.

    ``design`` is unbound along its first dimension.  For each slice the
    posterior covariance is obtained from ``analytic_posterior_cov``,
    restricted to the rows and columns belonging to ``target_labels``, and
    converted with the Gaussian entropy expression
    ``0.5 * logdet(2 * pi * e * C)``.

    Args:
        model: Object exposing ``w_sds`` (an ordered mapping of tensors whose
            values are concatenated along the last dimension) and ``obs_sd``.
        design: Tensor whose first dimension indexes the individual designs.
        observation_labels: Forwarded to ``lm_H_prior``; not otherwise read.
        target_labels: A single label or a list of labels selecting entries
            of ``model.w_sds``.
        eig: When truthy, return ``lm_H_prior(...)`` minus the posterior
            entropy; otherwise return the posterior entropy alone.
            (Presumably "expected information gain" -- confirm with callers.)

    Returns:
        One-dimensional tensor with one value per slice of ``design``.
    """
    if isinstance(target_labels, str):
        target_labels = [target_labels]

    # Diagonal prior covariance built from the concatenated standard deviations.
    prior_sds = torch.cat(list(model.w_sds.values()), dim=-1)
    prior_cov = torch.diag(prior_sds**2)

    posterior_covs = []
    for single_design in torch.unbind(design):
        posterior_covs.append(
            analytic_posterior_cov(prior_cov, single_design, model.obs_sd)
        )

    target_indices = get_indices(target_labels, tensors=model.w_sds)
    target_covs = []
    for cov in posterior_covs:
        # Keep only the target rows, then only the target columns.
        target_rows = cov[target_indices, :]
        target_covs.append(target_rows[:, target_indices])

    # The prior term is only needed (and only evaluated) when eig is requested.
    want_gain = bool(eig)
    prior_entropy = None
    if want_gain:
        prior_entropy = lm_H_prior(model, design, observation_labels, target_labels)

    gaussian_scale = 2 * math.pi * math.e
    posterior_entropy = torch.tensor(
        [0.5 * torch.logdet(gaussian_scale * cov) for cov in target_covs]
    )

    if want_gain:
        return prior_entropy - posterior_entropy
    return posterior_entropy
def lm_H_prior(model, design, observation_labels, target_labels):
    """Return the Gaussian entropy of the prior over the target entries.

    A diagonal covariance is built from the squared, concatenated values of
    ``model.w_sds``; the rows and columns selected by ``target_labels`` are
    extracted and passed through ``0.5 * logdet(2 * pi * e * C)``.

    Args:
        model: Object exposing ``w_sds``, an ordered mapping of tensors.
        design: Accepted for interface compatibility; not read here.
        observation_labels: Accepted for interface compatibility; not read
            here.
        target_labels: A single label or a list of labels selecting entries
            of ``model.w_sds``.

    Returns:
        Scalar tensor holding the entropy value.
    """
    labels = [target_labels] if isinstance(target_labels, str) else target_labels

    prior_sds = torch.cat(list(model.w_sds.values()), dim=-1)
    full_prior_cov = torch.diag(prior_sds**2)

    idx = get_indices(labels, tensors=model.w_sds)
    # Keep only the target rows, then only the target columns.
    target_rows = full_prior_cov[idx, :]
    target_cov = target_rows[:, idx]

    gaussian_scale = 2 * math.pi * math.e
    return 0.5 * torch.logdet(gaussian_scale * target_cov)