Example #1
0
def test_get_indices_sizes():
    """Check that ``get_indices`` returns the flat index range of each label.

    The layout is three consecutive blocks ``a``, ``b``, ``c`` of length 2.
    It is described once by integer sizes (``sizes=``) and once by
    example tensors (``tensors=``). Both descriptions must yield the same
    indices.
    """
    sizes = OrderedDict([("a", 2), ("b", 2), ("c", 2)])
    tensors = OrderedDict([("a", torch.ones(2)), ("b", torch.ones(2)), ("c", torch.ones(2))])

    expected_b = torch.tensor([2, 3])
    expected_bc = torch.tensor([2, 3, 4, 5])

    # Exercise both keyword entry points against the same expectations.
    for layout in ({"sizes": sizes}, {"tensors": tensors}):
        assert_equal(get_indices(["b"], **layout), expected_b)
        assert_equal(get_indices(["b", "c"], **layout), expected_bc)
Example #2
0
def linear_model_ground_truth(model,
                              design,
                              observation_labels,
                              target_labels,
                              eig=True):
    """Compute the exact posterior entropy (or EIG) of the target weights.

    For every design matrix in ``design``:

    1. Build the prior covariance ``diag(w_sd**2)``.
    2. Pass it to ``analytic_posterior_cov`` to get a posterior covariance.
    3. Restrict that covariance to the coordinates of ``target_labels``.
    4. Evaluate ``0.5 * logdet(2*pi*e*C)``, the differential entropy of a
       Gaussian with covariance ``C``.

    :param model: object exposing ``w_sds`` (a mapping of label -> weight
        standard-deviation tensor, concatenated in iteration order) and
        ``obs_sd``.
    :param torch.Tensor design: tensor of shape ``batch_shape + (n, p)``;
        the last two dimensions form one design matrix.
    :param observation_labels: accepted but not used by this function.
    :param target_labels: a single label (``str``) or a list of labels
        selecting the weights of interest.
    :param bool eig: if True, return ``prior_entropy - posterior_entropy``
        with the prior entropy from ``mean_field_entropy``. Otherwise
        return the posterior entropy.
    :return: tensor of shape ``design.shape[:-2]``.
    """
    if isinstance(target_labels, str):
        target_labels = [target_labels]

    w_sd = torch.cat(list(model.w_sds.values()), dim=-1)
    prior_cov = torch.diag(w_sd**2)
    batch_shape = design.shape[:-2]
    # Flatten all batch dimensions so each design matrix is handled on its own.
    flat_design = design.reshape(-1, design.shape[-2], design.shape[-1])
    target_indices = get_indices(target_labels, tensors=model.w_sds)

    entropies = []
    for x in torch.unbind(flat_design):
        posterior_cov = analytic_posterior_cov(prior_cov, x, model.obs_sd)
        target_cov = posterior_cov[target_indices, :][:, target_indices]
        entropies.append(0.5 * torch.logdet(2 * math.pi * math.e * target_cov))

    # torch.stack, unlike torch.tensor on a list of tensors, keeps the
    # dtype, device and autograd graph of the computed entropies. That way
    # the subtraction below does not mix CPU and accelerator tensors.
    # torch.stack rejects an empty list, so an empty batch gets an empty
    # result instead.
    output = torch.stack(entropies) if entropies else design.new_empty(0)

    if eig:
        prior_entropy = mean_field_entropy(model, [design],
                                           whitelist=target_labels)
        output = prior_entropy - output

    return output.reshape(batch_shape)
Example #3
0
File: util.py Project: zyxue/pyro
def linear_model_ground_truth(model,
                              design,
                              observation_labels,
                              target_labels,
                              eig=True):
    """Compute the exact posterior entropy (or EIG) of the target weights.

    For each slice ``design[i]``:

    1. Build the prior covariance ``diag(w_sd**2)``.
    2. Pass it to ``analytic_posterior_cov`` to get a posterior covariance.
    3. Restrict that covariance to the coordinates of ``target_labels``.
    4. Evaluate ``0.5 * logdet(2*pi*e*C)``, the differential entropy of a
       Gaussian with covariance ``C``.

    :param model: object exposing ``w_sds`` (a mapping of label -> weight
        standard-deviation tensor, concatenated in iteration order) and
        ``obs_sd``.
    :param torch.Tensor design: tensor iterated along its first dimension;
        each ``design[i]`` is one design passed to ``analytic_posterior_cov``.
    :param observation_labels: forwarded to ``lm_H_prior`` when ``eig`` is
        True; otherwise unused.
    :param target_labels: a single label (``str``) or a list of labels
        selecting the weights of interest.
    :param bool eig: if True, return ``lm_H_prior(...) - posterior_entropy``.
        Otherwise return the posterior entropy.
    :return: 1-d tensor with one entry per ``design[i]``.
    """
    if isinstance(target_labels, str):
        target_labels = [target_labels]

    w_sd = torch.cat(list(model.w_sds.values()), dim=-1)
    prior_cov = torch.diag(w_sd**2)
    target_indices = get_indices(target_labels, tensors=model.w_sds)

    posterior_entropies = []
    for x in torch.unbind(design):
        posterior_cov = analytic_posterior_cov(prior_cov, x, model.obs_sd)
        target_cov = posterior_cov[target_indices, :][:, target_indices]
        posterior_entropies.append(
            0.5 * torch.logdet(2 * math.pi * math.e * target_cov))

    # torch.stack keeps dtype, device and autograd graph. torch.tensor on a
    # list of tensors would copy the values into a fresh default-dtype CPU
    # tensor. torch.stack rejects an empty list, so an empty design gets an
    # empty result instead.
    posterior_entropy = (torch.stack(posterior_entropies)
                         if posterior_entropies else design.new_empty(0))

    if not eig:
        return posterior_entropy
    prior_entropy = lm_H_prior(model, design, observation_labels,
                               target_labels)
    return prior_entropy - posterior_entropy
Example #4
0
File: util.py Project: zyxue/pyro
def lm_H_prior(model, design, observation_labels, target_labels):
    """Return ``0.5 * logdet(2*pi*e*C)`` for the target block of the prior.

    ``C`` is the prior covariance ``diag(w_sd**2)`` (with ``w_sd`` the
    concatenation of ``model.w_sds`` values), restricted to the rows and
    columns selected by ``target_labels``. ``target_labels`` may be a
    single ``str`` label or a list of labels. ``design`` and
    ``observation_labels`` are accepted but not used by this function.
    """
    labels = [target_labels] if isinstance(target_labels, str) else target_labels

    all_sds = torch.cat(list(model.w_sds.values()), dim=-1)
    idx = get_indices(labels, tensors=model.w_sds)

    # Slice rows, then columns, of the full diagonal prior covariance.
    full_cov = torch.diag(all_sds**2)
    target_cov = full_cov[idx, :][:, idx]

    scaled_cov = 2 * math.pi * math.e * target_cov
    return 0.5 * torch.logdet(scaled_cov)