def __init__(self, parents: Tuple[Gaussian], child: Gaussian):
    """Deterministic factor linking several Gaussian parents to one child.

    All messages start as uniform (improper) Gaussians so that the first
    message pass divides out a no-op.
    """
    VMPFactor.__init__(self)
    self._deterministic = True  # contributes no ELBO term of its own
    self.parents = parents
    self.child = child
    # One stored message per parent, plus the message toward the child.
    self.message_to_parents = tuple(
        Gaussian.uniform(parent.shape) for parent in parents)
    self.message_to_child = Gaussian.uniform(child.shape)
def __init__(self, parent: Gaussian, children: Tuple[Gaussian]):
    """Deterministic factor splitting one Gaussian parent column-wise into
    several children."""
    VMPFactor.__init__(self)
    self._deterministic = True
    self.parent = parent
    self.children = children
    # Column widths used when splitting the parent along dimension 1.
    self._chunks = tuple(child.shape[1] for child in children)
    self.message_to_parent = Gaussian.uniform(parent.shape)
    self.message_to_children = tuple(
        Gaussian.uniform(child.shape) for child in children)
def __init__(self, parent: Gaussian, index: torch.Tensor, child: Gaussian):
    """Deterministic factor selecting (possibly repeated) rows of the parent
    into the child, as indicated by `index`."""
    VMPFactor.__init__(self)
    self._deterministic = True
    self.parent = parent
    self.index = index
    self.child = child
    # Boolean mask: which[i, j] is True when child row j selects parent row i.
    n_parent_rows = parent.shape[0]
    self.which = torch.stack([index == i for i in range(n_parent_rows)])
    self.message_to_child = Gaussian.uniform(child.shape)
    # Per-child-row messages keep the child's shape; their row-wise
    # aggregation toward the parent is stored separately.
    self.message_to_parent = Gaussian.uniform(child.shape)
    self.message_to_parent_sum = Gaussian.uniform(parent.shape)
def __init__(self, parent: Gaussian, child: Gaussian, variance: torch.Tensor):
    # TODO: variance as tensor, variance in log scale
    """Stochastic Gaussian noise factor between a parent and a child node.

    NOTE(review): the parameter is named `log_var` and is exponentiated when
    used, yet it is initialized directly from `variance` without taking a
    log — presumably callers pass log-variance already; see the TODO above.
    """
    VMPFactor.__init__(self)
    self._deterministic = False
    self._observed = False
    self._prior = False
    # Learnable noise parameter, treated as log-variance downstream.
    self.log_var = nn.Parameter(variance)
    self.parent = parent
    self.child = child
    # NOTE(review): `self.shape` is not defined in this chunk — presumably a
    # property on this class or VMPFactor; confirm against the full file.
    self.message_to_child = Gaussian.uniform(self.shape)
    self.message_to_parent = Gaussian.uniform(self.shape)
def __init__(self, parent: Gaussian, child: Gaussian):
    """Deterministic affine factor: child ≈ parent @ weight + bias.

    The weight and bias are learnable nn.Parameters, randomly initialized.
    """
    VMPFactor.__init__(self)
    self._deterministic = True
    self.parent = parent
    self.child = child
    self.in_dim = parent.shape[1]
    self.out_dim = child.shape[1]
    # Match the parent's floating dtype for the affine parameters.
    dtype = parent.precision.dtype
    self.weight = nn.Parameter(
        torch.randn((self.in_dim, self.out_dim), dtype=dtype))
    self.bias = nn.Parameter(torch.randn((self.out_dim,), dtype=dtype))
    self.message_to_child = Gaussian.uniform(child.shape)
    self.message_to_parent_sum = Gaussian.uniform(parent.shape)
    # One message per (row, input, output) triple; the aggregate over the
    # output dimension is kept in message_to_parent_sum.
    self.message_to_parent = Gaussian.uniform(
        (parent.shape[0], self.in_dim, self.out_dim))
def to_child(self):
    """Send the parent's message, convolved with the factor noise, to the child."""
    incoming = self.parent / self.message_to_parent
    # Adding the noise variance widens the message; mean is unchanged.
    outgoing = Gaussian.from_array(
        incoming.mean,
        incoming.variance + self.log_var.exp().unsqueeze(0))
    self.child.update(self.message_to_child, outgoing)
    self.message_to_child = outgoing
def to_parent(self):
    """Send the child's message, convolved with the factor noise, to the parent."""
    incoming = self.child / self.message_to_child
    # Symmetric to to_child: noise variance is added, mean is kept.
    outgoing = Gaussian.from_array(
        incoming.mean,
        incoming.variance + self.log_var.exp().unsqueeze(0))
    self.parent.update(self.message_to_parent, outgoing)
    self.message_to_parent = outgoing
def __init__(self, parent: Gaussian, child: Bernoulli):
    """Stochastic logistic factor linking a Gaussian parent to a Bernoulli child."""
    VMPFactor.__init__(self)
    self._deterministic = False
    self._observed = False
    self.parent = parent
    self.child = child
    # NOTE(review): `self.shape` is not defined in this chunk — presumably a
    # property of this class or VMPFactor; confirm against the full file.
    self.message_to_child = Bernoulli.uniform(self.shape)
    self.message_to_parent = Gaussian.uniform(self.shape)
def to_parent(self, i):
    """Send a message to parent `i` using the other parent's moments.

    For child = x0 * x1, the message to x_i scales the child's natural
    parameters by the moments of the other factor.
    """
    from_child = self.child / self.message_to_child
    other_mean, other_var = self.parents[1 - i].mean_and_variance
    child_p, child_mtp = from_child.natural
    # Precision scales by E[other^2] = var + mean^2; the mean-times-precision
    # scales by E[other].
    precision = child_p * (other_var + other_mean ** 2)
    mean_times_precision = child_mtp * other_mean
    outgoing = Gaussian(precision, mean_times_precision)
    self.parents[i].update(self.message_to_parents[i], outgoing)
    self.message_to_parents[i].set_to(outgoing)
def to_parent(self):
    """Concatenate the children's messages back into the parent's column layout."""
    from_children = [
        child / msg
        for child, msg in zip(self.children, self.message_to_children)]
    precision = torch.cat([g.precision for g in from_children], 1)
    mtp = torch.cat([g.mean_times_precision for g in from_children], 1)
    outgoing = Gaussian(precision, mtp)
    self.parent.update(self.message_to_parent, outgoing)
    self.message_to_parent = outgoing
def to_child(self):
    """Moment-match the product of the two parents and push it to the child.

    The child marginal is overwritten (set_to) rather than multiplicatively
    updated; the stored message is whatever remains after dividing out the
    child's own incoming message.
    """
    m0, v0 = self.parents[0].mean_and_variance
    m1, v1 = self.parents[1].mean_and_variance
    # Exact first two moments of the product of two independent Gaussians.
    matched = Gaussian.from_array(
        m0 * m1,
        m0 ** 2 * v1 + m1 ** 2 * v0 + v0 * v1)
    from_child = self.child / self.message_to_child
    outgoing = matched / from_child
    self.child.set_to(matched)
    self.message_to_child = outgoing
def to_child(self):
    """Split the parent's message column-wise and forward each chunk to its child."""
    from_parent = self.parent / self.message_to_parent
    precision, mtp = from_parent.natural
    pieces = zip(precision.split(self._chunks, 1), mtp.split(self._chunks, 1))
    new_messages = tuple(Gaussian(p_i, mtp_i) for p_i, mtp_i in pieces)
    for child, old_msg, new_msg in zip(
            self.children, self.message_to_children, new_messages):
        child.update(old_msg, new_msg)
    self.message_to_children = new_messages
def __init__(self, parents: Tuple[Gaussian], child: Gaussian):
    """Inner product factor, composed of an elementwise Product followed by a
    frozen all-ones Linear (i.e. a sum over columns)."""
    VMPFactor.__init__(self)
    self._deterministic = True
    self.parents = parents
    self.child = child
    # Intermediate node holding the elementwise products.
    self._products = Gaussian.uniform(parents[0].shape)
    self._product = Product(parents, self._products)
    self._linear = Linear(self._products, child)
    # Freeze the linear map at weight = 1, bias = 0 so it reduces to a row sum.
    for param, fill in ((self._linear.weight, torch.ones_like),
                        (self._linear.bias, torch.zeros_like)):
        param.data = fill(param.data)
        param.requires_grad = False
def to_child(self):
    """Forward the sum of the parents' means and variances to the child."""
    from_parents = tuple(
        parent / msg
        for parent, msg in zip(self.parents, self.message_to_parents))
    means = torch.cat([g.mean for g in from_parents], 1)
    variances = torch.cat([g.variance for g in from_parents], 1)
    # nansum: NaN moments (presumably from degenerate/uniform messages)
    # are skipped rather than poisoning the sum — TODO confirm intent.
    outgoing = Gaussian.from_array(
        means.nansum(1, keepdim=True),
        variances.nansum(1, keepdim=True))
    self.child.update(self.message_to_child, outgoing)
    self.message_to_child = outgoing
def to_child(self):
    """Push the affine-transformed parent moments to the child."""
    from_parent = self.parent.unsqueeze(-1) / self.message_to_parent
    m, v = from_parent.mean_and_variance
    w = self.weight.unsqueeze(0)
    # Mean maps through the affine transform; variance through squared weights.
    mean = (m * w).nansum(1) + self.bias.unsqueeze(0)
    var = (v * w ** 2).nansum(1)
    outgoing = Gaussian.from_array(mean, var)
    self.child.update(self.message_to_child, outgoing)
    self.message_to_child = outgoing
def __init__(self, positions: Gaussian, index: torch.Tensor,
             X_cts: torch.Tensor, X_bin: torch.Tensor):
    """Observation model: select rows of `positions`, map them linearly to a
    mean vector, split into continuous / binary columns, and attach a Gaussian
    likelihood to X_cts and a logistic likelihood to X_bin."""
    f.VMPFactor.__init__(self)
    self._deterministic = False
    n_cts = X_cts.shape[1]
    n_bin = X_bin.shape[1]
    n_obs = index.shape[0]
    latent_dim = positions.shape[1]
    # Intermediate variational nodes along the chain.
    self._positions = Gaussian.uniform((n_obs, latent_dim))
    self._mean = Gaussian.uniform((n_obs, n_cts + n_bin))
    self._mean_cts = Gaussian.uniform((n_obs, n_cts))
    self._mean_bin = Gaussian.uniform((n_obs, n_bin))
    # Factor chain: select -> linear -> split -> (gaussian | logistic).
    self._select = f.Select(positions, index, self._positions)
    self._linear = f.Linear(self._positions, self._mean)
    self._split = f.Split(self._mean, (self._mean_cts, self._mean_bin))
    self._gaussian = f.GaussianFactor.observed(
        self._mean_cts, X_cts, torch.zeros(n_cts))
    self._logistic = f.Logistic.observed(self._mean_bin, X_bin)
def to_parent(self):
    """Gaussian approximation of the logistic likelihood's message to the parent.

    Uses sigmoid-Gaussian integrals of order 0 and 1 to build the natural
    parameters; rows with an uninformative child fall back to a uniform message.
    """
    m, v = self.parent.mean_and_variance
    integrals = sigmoid_integrals(m, v, [0, 1])
    sd = torch.sqrt(v)
    # First moment of x * sigma(x) pieces under N(m, v).
    exp1 = m * integrals[0] + sd * integrals[1]
    p = (exp1 - m * integrals[0]) / v
    mtp = m * p + self.child.proba - integrals[0]
    # Mask out rows where the child carries no information.
    uninformative = self.child.is_uniform
    p = torch.where(
        uninformative, torch.full_like(p, Gaussian.uniform_precision), p)
    mtp = torch.where(uninformative, torch.zeros_like(mtp), mtp)
    outgoing = Gaussian(p, mtp)
    self.parent.update(self.message_to_parent, outgoing)
    self.message_to_parent = outgoing
def __init__(self, positions: Gaussian, heterogeneity: Gaussian,
             indices: Tuple[torch.Tensor], links: torch.Tensor):
    """Edge model: logit = <z_i, z_j> + a_i + a_j, pushed through a logistic
    likelihood on the observed `links`."""
    f.VMPFactor.__init__(self)
    self._deterministic = False
    n_edges = indices[0].shape[0]
    latent_dim = positions.shape[1]
    # Variational nodes for the two endpoints of every edge.
    self._positions = tuple(
        Gaussian.uniform((n_edges, latent_dim)) for _ in range(2))
    self._heterogeneity = tuple(
        Gaussian.uniform((n_edges, 1)) for _ in range(2))
    self._inner_products = Gaussian.uniform((n_edges, 1))
    self._logits = Gaussian.uniform((n_edges, 1))
    # Wiring: select endpoints, inner product, add heterogeneity, logistic.
    self._select_positions = tuple(
        f.Select(positions, idx, node)
        for idx, node in zip(indices, self._positions))
    self._select_heterogeneity = tuple(
        f.Select(heterogeneity, idx, node)
        for idx, node in zip(indices, self._heterogeneity))
    self._inner_product = f.InnerProduct(self._positions, self._inner_products)
    self._sum = f.Sum((self._inner_products, *self._heterogeneity), self._logits)
    self._logistic = f.Logistic.observed(self._logits, links)
def to_parent(self):
    """Aggregate the child's messages row-wise into the parent.

    Only a random subset (about half, with duplicates removed) of parent rows
    is refreshed on each call — the original author notes this seems to help
    convergence. Rows not selected get a zero (uniform) aggregate this round.
    """
    from_child = self.child / self.message_to_child
    p, mtp = from_child.natural
    p_sum = torch.zeros(self.parent.shape, device=p.device, dtype=p.dtype)
    mtp_sum = torch.zeros(self.parent.shape, device=p.device, dtype=p.dtype)
    n = self.parent.shape[0]
    # BUG FIX: the original did `set(torch.randint(...))`. Iterating a tensor
    # yields 0-dim tensors, which hash by object identity, so `set` never
    # deduplicated the sampled indices — duplicate rows were recomputed.
    # Converting to Python ints with .tolist() makes the dedup effective.
    rows = set(torch.randint(0, n, (n // 2,)).tolist())
    for i in rows:
        # nansum: skip NaN entries so degenerate child rows don't poison
        # the aggregate — TODO confirm this matches the Gaussian conventions.
        p_sum[i, ] = p[self.which[i, ], ].nansum(0)
        mtp_sum[i, ] = mtp[self.which[i, ], ].nansum(0)
    message_to_parent = Gaussian(p, mtp)
    message_to_parent_sum = Gaussian(p_sum, mtp_sum)
    self.parent.update(self.message_to_parent_sum, message_to_parent_sum)
    self.message_to_parent = message_to_parent
    self.message_to_parent_sum = message_to_parent_sum
def to_parent(self):
    """Send each parent the child's message minus the other parents' contributions."""
    from_child = self.child / self.message_to_child
    cm, cv = from_child.mean_and_variance
    from_parents = tuple(
        parent / msg
        for parent, msg in zip(self.parents, self.message_to_parents))
    parent_means = tuple(g.mean for g in from_parents)
    parent_vars = tuple(g.variance for g in from_parents)
    total_mean = torch.cat(parent_means, 1).nansum(1, keepdim=True)
    total_var = torch.cat(parent_vars, 1).nansum(1, keepdim=True)
    # For parent i: mean = cm - sum_{j != i} m_j; var = cv + sum_{j != i} v_j.
    new_messages = tuple(
        Gaussian.from_array(cm - total_mean + m_i, cv + total_var - v_i)
        for m_i, v_i in zip(parent_means, parent_vars))
    for parent, old_msg, new_msg in zip(
            self.parents, self.message_to_parents, new_messages):
        parent.update(old_msg, new_msg)
    self.message_to_parents = new_messages
def to_parent(self):
    """Send the child's message back through the affine map to the parent.

    Builds per-(input, output) messages via leave-one-out sums over the input
    dimension, then collapses them over the output dimension for the parent.
    """
    from_child = self.child / self.message_to_child
    cm, cv = from_child.mean_and_variance
    from_parent = self.parent.unsqueeze(-1) / self.message_to_parent
    pm, pv = from_parent.mean_and_variance
    w = self.weight.unsqueeze(0)
    b = self.bias.unsqueeze(0).unsqueeze(0)
    # Leave-one-out contributions of the other inputs.
    pm_rest = (pm * w).nansum(1, keepdim=True) - pm * w
    pv_rest = (pv * w ** 2).nansum(1, keepdim=True) - pv * w ** 2
    mean = (cm.unsqueeze(1) - b - pm_rest) / w
    var = (cv.unsqueeze(1) + pv_rest) / (w ** 2)
    outgoing = Gaussian.from_array(mean, var)
    # Collapse over the output dimension into one message per input.
    outgoing_sum = outgoing.product(-1)
    self.parent.update(self.message_to_parent_sum, outgoing_sum)
    self.message_to_parent = outgoing
    self.message_to_parent_sum = outgoing_sum
# ----------------------------------------------------------------------------- # GaussianFactor import torch import numpy as np from NNVI.vmp.gaussian import Gaussian from NNVI.vmp.factors import GaussianFactor # stochastic case shape = (2, 3) parent = Gaussian.from_shape(shape, 0., 1.) child = Gaussian.from_shape(shape, 1., 2.) self = GaussianFactor(parent, child, torch.arange(3).double()) self.forward() self.backward() self self.to_elbo() # prior case shape = (2, 3) child = Gaussian.from_shape(shape, 1., 2.) self = GaussianFactor.prior(child, 0., 1.) self.forward() self.backward() self self.to_elbo() # observed case shape = (2, 3) parent = Gaussian.from_shape(shape, 0., 1.) child = torch.randn(shape)
# Scratch/test script exercising the Gaussian container type interactively:
# construction from natural parameters, updates, indexing, in-place multiply,
# entropy, device transfer, split and concatenation.
import torch
from NNVI.vmp.gaussian import Gaussian

# Build from natural parameters (precision, mean-times-precision).
p = torch.ones((2, 3)) * 2.
mtp = torch.ones((2, 3)) * 3.
self = Gaussian(p, mtp)
other = Gaussian(p*0.2, mtp*0.3)
other2 = Gaussian(p*0.5, mtp*0.5)
self.update(other, other2)
self[1].update(other[1], other2[1])  # row-indexed view update
self *= other  # in-place product of Gaussians
p = torch.ones((3, )) * 4.
mtp = torch.ones((3, )) * 5.
self[0, :] = Gaussian(p, mtp)  # slice assignment
self.entropy()
self.negative_entropy()
Gaussian.point_mass(p)
self.cuda()  # requires a CUDA device
self.split((1, 2), 1)
Gaussian.cat([self, other, other2], 0)
def observed(cls, parent: Gaussian, child: torch.Tensor,
             variance: torch.Tensor):
    """Alternate constructor for a factor whose child is observed data.

    Fix: the `child` annotation was `torch.tensor` (the factory function),
    not the type `torch.Tensor`; runtime behavior is unchanged.
    """
    # Wrap the raw data tensor as an observed (fixed) Gaussian node.
    obj = cls(parent, Gaussian.observed(child), variance)
    obj._observed = True
    return obj
def prior(cls, child: Gaussian, mean: float, variance: torch.Tensor):
    """Alternate constructor: a prior factor with a point-mass parent at `mean`."""
    # The parent is a degenerate (point-mass) Gaussian matching the child's shape.
    point_parent = Gaussian.point_mass(torch.full(child.shape, mean))
    obj = cls(point_parent, child, variance)
    obj._prior = True
    return obj