def test_tensor_to_dict():
    """tensor_to_dict should split a flat vector into labelled chunks,
    optionally restricted to a subset of labels."""
    sizes = OrderedDict([("a", 2), ("b", 2), ("c", 2)])
    vector = torch.tensor([1., 2, 3, 4, 5, 6])
    expected = {
        "a": torch.tensor([1., 2.]),
        "b": torch.tensor([3., 4.]),
        "c": torch.tensor([5., 6.]),
    }
    assert_equal(tensor_to_dict(sizes, vector), expected)
    # Restricting to a subset returns only the requested entries.
    assert_equal(tensor_to_dict(sizes, vector, subset=["b"]),
                 {"b": torch.tensor([3., 4.])})
def linear_model_formula(self, y, design, target_labels):
    """Compute a Gaussian approximation for the requested parameter labels.

    The mean solves the Tikhonov-regularised normal equations
    ``(X^T X + diag(softplus(tikhonov_diag)))^{-1} X^T y``; the
    scale_tril entries are read off the learned per-label matrices.
    """
    design_t = design.transpose(-1, -2)
    # softplus keeps the regulariser positive, so X^T X + reg stays invertible.
    regulariser = torch.diag(self.softplus(self.tikhonov_diag))
    xtx = torch.matmul(design_t, design) + regulariser
    xtx_inv = rinverse(xtx, sym=True)
    full_mean = rmv(xtx_inv, rmv(design_t, y))
    # Keep only the sub-vectors corresponding to the requested labels.
    mu = tensor_to_dict(self.w_sizes, full_mean, subset=target_labels)
    scale_tril = {label: rtril(self.scale_tril[label])
                  for label in target_labels}
    return mu, scale_tril
def __init__(self, d, w_sizes, tau_label=None, init_value=0.1, **kwargs):
    """Set up learnable per-label mean parameters.

    :param d: batch shape prefix — an int, tuple, list, or 1-d Tensor of
        sizes.  (The original code accepted list/Tensor in its isinstance
        check but then crashed on ``d + (...,)``; all forms now work.)
    :param w_sizes: OrderedDict mapping label -> flattened size.  Copied
        internally, so the caller's dict is no longer mutated when
        ``tau_label`` is added.
    :param tau_label: optional extra label given size 1.
    :param init_value: constant used to initialise every mean entry.
    """
    super().__init__()
    # start in train mode
    self.train()
    # Normalise `d` to a tuple of ints so it concatenates cleanly with
    # the trailing parameter dimension below.
    if isinstance(d, torch.Tensor):
        d = tuple(int(x) for x in d)
    elif isinstance(d, (tuple, list)):
        d = tuple(d)
    else:
        d = (d,)
    self.means = nn.ParameterDict()
    # Copy before possibly adding tau_label: avoid mutating the caller's dict.
    w_sizes = w_sizes.copy()
    if tau_label is not None:
        w_sizes[tau_label] = 1
    total_size = sum(w_sizes.values())
    init = init_value * torch.ones(*(d + (total_size,)))
    for label, mu_l in tensor_to_dict(w_sizes, init).items():
        self.means[label] = nn.Parameter(mu_l)
    self.scale_trils = {}
    self.w_sizes = w_sizes