def linear_model_formula(self, y, design, target_labels):
    """Return per-label means and scale factors for the requested labels.

    For each label ``l`` in ``target_labels`` the mean is ``rmv(W_l, y)``.
    ``W_l`` is ``self.regressor[l]`` passed through ``self.softplus`` when
    ``self.use_softplus`` is truthy, and is used unchanged otherwise.
    The scale factor for ``l`` is ``rtril(self.scale_tril[l])``.

    Args:
        y: tensor handed to ``rmv`` as its second argument (presumably a
            batched matrix-vector product helper -- confirm against its
            definition).
        design: not referenced by this formula. Presumably it is accepted
            so the signature matches other ``linear_model_formula``
            implementations.
        target_labels: iterable of labels to produce outputs for. It is
            iterated twice, once per output dict.

    Returns:
        Tuple ``(mu, scale_tril)`` of dicts keyed by the labels in
        ``target_labels``.
    """
    # Pick the weight transform once instead of branching on every label.
    if self.use_softplus:
        transform = self.softplus
    else:
        def transform(weights):
            return weights

    means = {}
    for label in target_labels:
        means[label] = rmv(transform(self.regressor[label]), y)

    tril_factors = {label: rtril(self.scale_tril[label]) for label in target_labels}
    return means, tril_factors
def linear_model_formula(self, y, design, target_labels):
    """Tikhonov-regularised least-squares means for the requested labels.

    Computes ``mu = (X^T X + D)^{-1} X^T y``. Here ``X`` is ``design`` and
    ``D`` is the diagonal matrix whose diagonal is
    ``self.softplus(self.tikhonov_diag)``. The stacked ``mu`` is then split
    into per-label pieces via ``tensor_to_dict(self.w_sizes, ...)``.

    Args:
        y: observation tensor. It is multiplied by ``design^T`` through
            ``rmv``, so presumably its last dim matches ``design``'s
            second-to-last dim -- TODO confirm.
        design: design matrix ``X``. Its last two dims are treated as the
            matrix dims (see the ``transpose(-1, -2)`` calls).
        target_labels: labels whose slices of ``mu`` (and whose
            ``scale_tril`` entries) are returned.

    Returns:
        Tuple ``(mu, scale_tril)`` of dicts keyed by the labels in
        ``target_labels``.
    """
    design_t = design.transpose(-1, -2)
    # Build the regulariser from the last dim of tikhonov_diag.
    # diag_embed equals torch.diag for a 1-D vector. For a batched
    # tikhonov_diag it yields a batch of diagonal matrices, whereas
    # torch.diag would extract a diagonal from 2-D input (or fail on
    # higher ranks).
    # Assuming self.softplus is positive-valued, the diagonal is positive.
    # X^T X + D is then positive definite, so it stays invertible even
    # when X^T X is singular.
    tikhonov = torch.diag_embed(self.softplus(self.tikhonov_diag))
    xtx = torch.matmul(design_t, design) + tikhonov
    # sym=True: xtx is symmetric by construction. Presumably rinverse
    # exploits that -- confirm against its definition.
    xtxi = rinverse(xtx, sym=True)
    mu = rmv(xtxi, rmv(design_t, y))
    # Extract the sub-vectors belonging to each requested label.
    mu = tensor_to_dict(self.w_sizes, mu, subset=target_labels)
    scale_tril = {l: rtril(self.scale_tril[l]) for l in target_labels}
    return mu, scale_tril
def test_rtril():
    """Check rtril against torch.tril on one matrix and on a tiled batch of it."""
    matrix = torch.tensor([[1., 2.], [-2., 0]])
    lower = torch.tril(matrix)

    # Single-matrix case.
    assert_equal(rtril(matrix), lower, prec=1e-8)

    # Batched case: tile both the input and the reference with lexpand(…, 5, 4)
    # and require rtril to match the tiled lower-triangular reference.
    batched_input = lexpand(matrix, 5, 4)
    batched_lower = lexpand(lower, 5, 4)
    assert_equal(rtril(batched_input), batched_lower, prec=1e-8)