def get_params(self, y_dict, design, target_labels):
    """Return ``(mu, scale_tril, alpha, beta)`` for the given observations.

    The tensors in ``y_dict`` are concatenated along the last dimension to
    form the observation vector. Every label in ``target_labels`` other than
    ``self.tau_label`` is forwarded to ``self.linear_model_formula``, which
    supplies ``mu`` (a mapping whose values are tensors) and ``scale_tril``.

    ``beta`` is ``self.b0 + 0.5 * (rvv(y, y) - rvv(y, rmv(design, mu_vec)))``,
    where ``mu_vec`` is the values of ``mu`` concatenated on the last dim.
    ``alpha`` is returned as ``self.alpha`` unchanged.
    """
    observations = torch.cat(list(y_dict.values()), dim=-1)
    # tau is handled separately; only the remaining targets are coefficients.
    coeff_labels = [lbl for lbl in target_labels if lbl != self.tau_label]
    mu, scale_tril = self.linear_model_formula(observations, design, coeff_labels)

    stacked_mu = torch.cat(list(mu.values()), dim=-1)
    fitted = rmv(design, stacked_mu)
    quad_term = rvv(observations, observations) - rvv(observations, fitted)
    beta = self.b0 + 0.5 * quad_term

    return mu, scale_tril, self.alpha, beta
def dv_critic(design, trace, observation_labels, target_labels):
    """Quadratic critic returning ``rvv(x, rmv(B, x))``.

    ``x`` is the last-dim concatenation of the trace values for
    ``target_labels`` followed by those for ``observation_labels``. ``B`` is
    the Pyro parameter ``"B"``, initialised to a square zero matrix matching
    the width of ``x``. ``design`` is accepted but not used.
    """
    y_dict = {label: trace.nodes[label]["value"] for label in observation_labels}
    theta_dict = {label: trace.nodes[label]["value"] for label in target_labels}
    x = torch.cat(list(theta_dict.values()) + list(y_dict.values()), dim=-1)
    # Size B from x instead of hard-coding 5, and match x's dtype/device, so
    # the critic works for any concatenated width and for float64/GPU inputs.
    # The initial value is only used the first time "B" is created.
    dim = x.shape[-1]
    B = pyro.param("B", x.new_zeros(dim, dim))
    return rvv(x, rmv(B, x))
def test_rvv(a, b):
    """Check that ``rvv`` agrees with ``torch.dot``, with and without batch dims."""
    expected = torch.dot(a, b)

    # Plain vectors: rvv should equal the dot product.
    assert_equal(rvv(a, b), expected, prec=1e-8)

    # Copies of a and b expanded via lexpand(..., 5, 4): the result should be
    # the same dot product expanded the same way.
    batch_dims = (5, 4)
    assert_equal(
        rvv(lexpand(a, *batch_dims), lexpand(b, *batch_dims)),
        lexpand(expected, *batch_dims),
        prec=1e-8,
    )