예제 #1
0
    def get_params(self, y_dict, design, target_labels):
        """Compute ``(mu, scale_tril, alpha, beta)`` from observations and a design.

        The observed values in ``y_dict`` are concatenated along the last
        dimension. The linear-model formula is fitted on every target label
        except ``self.tau_label``. ``beta`` is then formed as
        ``self.b0 + 0.5 * (rvv(y, y) - rvv(y, rmv(design, mu_vec)))``.

        Args:
            y_dict: Mapping of observation label -> tensor; values are joined
                with ``torch.cat(..., dim=-1)``.
            design: Design tensor passed to ``self.linear_model_formula`` and
                multiplied against the stacked ``mu`` via ``rmv`` (assumed to
                be a batched matrix-vector product -- TODO confirm).
            target_labels: Iterable of labels; ``self.tau_label`` is excluded
                before calling the formula.

        Returns:
            tuple: ``(mu, scale_tril, self.alpha, beta)``, where ``mu`` and
            ``scale_tril`` come straight from ``self.linear_model_formula``.
        """
        observations = torch.cat(list(y_dict.values()), dim=-1)

        # Drop the tau label; only the remaining labels go to the formula.
        coefficient_labels = [
            name for name in target_labels if name != self.tau_label
        ]
        mu, scale_tril = self.linear_model_formula(
            observations, design, coefficient_labels
        )

        stacked_mu = torch.cat(list(mu.values()), dim=-1)
        predicted = rmv(design, stacked_mu)
        # rvv is a (batched) dot product over the last dimension.
        residual_term = rvv(observations, observations) - rvv(observations, predicted)
        beta = self.b0 + 0.5 * residual_term

        return mu, scale_tril, self.alpha, beta
예제 #2
0
def dv_critic(design, trace, observation_labels, target_labels):
    """Evaluate a learnable quadratic-form critic ``x^T B x`` on a trace.

    ``x`` is the last-dimension concatenation of the target-site values
    followed by the observation-site values. Both are read from
    ``trace.nodes[label]["value"]``. ``B`` is the ``pyro.param`` named ``"B"``
    and is initialised to a zero matrix whose side length equals the
    feature dimension of ``x``.

    Args:
        design: Not used by this critic.
        trace: Object exposing ``trace.nodes[label]["value"]`` for each label.
        observation_labels: Labels of the observation sites.
        target_labels: Labels of the target (parameter) sites.

    Returns:
        ``rvv(x, rmv(B, x))``.
    """
    # Dicts (not lists) are kept deliberately: repeated labels collapse to one entry.
    y_dict = {label: trace.nodes[label]["value"] for label in observation_labels}
    theta_dict = {label: trace.nodes[label]["value"] for label in target_labels}
    x = torch.cat(list(theta_dict.values()) + list(y_dict.values()), dim=-1)

    # Size B from the data rather than hard-coding 5, so the critic works for
    # any combined feature dimension. The initial value only matters when the
    # param store does not already hold "B".
    dim = x.shape[-1]
    B = pyro.param("B", torch.zeros(dim, dim))
    return rvv(x, rmv(B, x))
예제 #3
0
파일: test_util.py 프로젝트: yufengwa/pyro
def test_rvv(a, b):
    """Check ``rvv`` against ``torch.dot``, both unbatched and with batch dims (5, 4)."""
    batch_shape = (5, 4)
    expected = torch.dot(a, b)
    assert_equal(rvv(a, b), expected, prec=1e-8)

    # Expanding both inputs on the left must give the same dot product at every batch index.
    batched_result = rvv(lexpand(a, *batch_shape), lexpand(b, *batch_shape))
    assert_equal(batched_result, lexpand(expected, *batch_shape), prec=1e-8)