def to_parent(self, i):
    """Compute and send the message addressed to parent ``i``.

    The message is built from the child-side message (the child belief
    divided by the message previously sent to the child) and from the mean
    and variance of the *other* parent, ``self.parents[1 - i]``.

    Args:
        i: Index of the receiving parent. The ``1 - i`` lookup implies two
            parents indexed 0 and 1 -- TODO confirm against the constructor.
    """
    upward = self.child / self.message_to_child
    sibling_mean, sibling_var = self.parents[1 - i].mean_and_variance
    child_precision, child_mtp = upward.natural

    # Second moment of the other parent: variance plus squared mean.
    sibling_second_moment = sibling_var + sibling_mean**2
    precision = child_precision * sibling_second_moment
    mean_times_precision = child_mtp * sibling_mean
    outgoing = Gaussian(precision, mean_times_precision)

    # Let the parent swap the previous message for the new one, then store
    # the new message in place so the same object keeps being reused.
    previous = self.message_to_parents[i]
    self.parents[i].update(previous, outgoing)
    self.message_to_parents[i].set_to(outgoing)
def to_parent(self):
    """Concatenate the children's incoming messages and send them upward.

    Each child belief is divided by the message previously sent to that
    child; the resulting precisions and mean-times-precisions are joined
    along dimension 1 and wrapped in a single ``Gaussian`` for the parent.
    """
    incoming = [
        child / sent
        for child, sent in zip(self.children, self.message_to_children)
    ]
    precisions = [message.precision for message in incoming]
    mean_times_precisions = [
        message.mean_times_precision for message in incoming
    ]
    outgoing = Gaussian(
        torch.cat(precisions, 1),
        torch.cat(mean_times_precisions, 1),
    )

    # The parent replaces the previously sent message with the new one.
    self.parent.update(self.message_to_parent, outgoing)
    self.message_to_parent = outgoing
def to_child(self):
    """Split the parent's incoming message and send one piece per child.

    The parent belief is divided by the message previously sent to it; its
    natural parameters are split along dimension 1 according to
    ``self._chunks`` and each pair of pieces becomes one child's message.
    """
    downward = self.parent / self.message_to_parent
    precision, mean_times_precision = downward.natural
    precision_parts = precision.split(self._chunks, 1)
    mtp_parts = mean_times_precision.split(self._chunks, 1)

    # One Gaussian per pair of chunks, in chunk order.
    fresh_messages = tuple(map(Gaussian, precision_parts, mtp_parts))

    # Each child replaces its previously received message with the new one.
    pairing = zip(self.children, self.message_to_children, fresh_messages)
    for child, previous, fresh in pairing:
        child.update(previous, fresh)
    self.message_to_children = fresh_messages
def to_parent(self):
    """Compute and send the Gaussian message from this factor to its parent.

    The natural parameters are derived from the parent's current mean and
    variance, the values returned by ``sigmoid_integrals`` (defined
    elsewhere; presumably the integrals selected by the ``[0, 1]``
    argument -- TODO confirm) and the child's ``proba``. Wherever the child
    is flagged as uniform, the message is replaced by
    ``Gaussian.uniform_precision`` with a zero mean-times-precision.
    """
    mean, var = self.parent.mean_and_variance
    integrals = sigmoid_integrals(mean, var, [0, 1])
    std = torch.sqrt(var)

    integral_0 = integrals[0]
    integral_1 = integrals[1]

    # Arithmetic deliberately mirrors the original expression term by term
    # so the floating-point result is unchanged.
    mean_term = mean * integral_0
    first_moment = mean_term + std * integral_1
    precision = (first_moment - mean_term) / var
    mean_times_precision = mean * precision + self.child.proba - integral_0

    # Entries whose child is uniform get the uniform message instead.
    uniform_mask = self.child.is_uniform
    precision = torch.where(
        uniform_mask,
        torch.full_like(precision, Gaussian.uniform_precision),
        precision,
    )
    mean_times_precision = torch.where(
        uniform_mask,
        torch.full_like(mean_times_precision, 0.),
        mean_times_precision,
    )

    outgoing = Gaussian(precision, mean_times_precision)
    self.parent.update(self.message_to_parent, outgoing)
    self.message_to_parent = outgoing
def to_parent(self):
    """Send the per-row summed child message to the parent.

    The child-side message (child belief divided by the message previously
    sent to the child) is reduced row by row: for each selected parent row
    ``i``, the entries picked out by ``self.which[i]`` are summed with
    ``nansum`` along dimension 0. Only a random subset of rows (at most
    ``n // 2`` distinct indices) is refreshed per call; all other rows of
    the summed message are zero.

    Both the raw child-side message and the summed message are stored on
    the instance; only the summed one is passed to ``self.parent.update``.
    """
    message_from_child = self.child / self.message_to_child
    p, mtp = message_from_child.natural
    p_sum = torch.zeros(self.parent.shape, device=p.device, dtype=p.dtype)
    mtp_sum = torch.zeros(self.parent.shape, device=p.device, dtype=p.dtype)

    # An earlier variant swept every row of the parent deterministically;
    # random partial updates seem to help a bit.
    n = self.parent.shape[0]
    # Convert to Python ints before building the set: tensors hash by
    # identity, so a set of 0-d tensors never removes duplicate indices and
    # the same row would be recomputed several times.
    rows = set(torch.randint(0, n, (n // 2, )).tolist())
    for i in rows:
        p_sum[i, ] = p[self.which[i, ], ].nansum(0)
        mtp_sum[i, ] = mtp[self.which[i, ], ].nansum(0)

    message_to_parent = Gaussian(p, mtp)
    message_to_parent_sum = Gaussian(p_sum, mtp_sum)
    self.parent.update(self.message_to_parent_sum, message_to_parent_sum)
    self.message_to_parent = message_to_parent
    self.message_to_parent_sum = message_to_parent_sum
# Scratch script that exercises the Gaussian API one call at a time.
# Statement order matters: several calls mutate `self` in place.
import torch
from NNVI.vmp.gaussian import Gaussian

# (2, 3) tensors filled with 2.0 and 3.0, passed as the two constructor
# arguments (named after precision and mean-times-precision).
p = torch.ones((2, 3)) * 2.
mtp = torch.ones((2, 3)) * 3.
# NOTE(review): the instance is bound to the name `self`, presumably so that
# method bodies can be pasted and run interactively -- confirm.
self = Gaussian(p, mtp)
other = Gaussian(p*0.2, mtp*0.3)
other2 = Gaussian(p*0.5, mtp*0.5)

# update() on the whole object, then on an indexed element.
self.update(other, other2)
self[1].update(other[1], other2[1])
# In-place multiplication by another Gaussian.
self *= other

# Rebind p and mtp to length-3 tensors and assign them into row 0 of `self`.
p = torch.ones((3, )) * 4.
mtp = torch.ones((3, )) * 5.
self[0, :] = Gaussian(p, mtp)

# The return values below are discarded; the calls only check that each
# method runs.
self.entropy()
self.negative_entropy()
Gaussian.point_mass(p)
# NOTE(review): presumably moves the data to a CUDA device and will fail on a
# machine without one -- confirm.
self.cuda()
self.split((1, 2), 1)
Gaussian.cat([self, other, other2], 0)