def to_parent(self):
    """Send an updated message from this factor to its parent node.

    The outgoing message takes the child's belief (with this factor's
    previous downstream message divided out) and inflates its variance
    by ``exp(log_var)``, the factor's noise variance.
    """
    # Child belief excluding what this factor previously sent it.
    incoming = self.child / self.message_to_child
    noise = self.log_var.exp().unsqueeze(0)
    fresh = Gaussian.from_array(incoming.mean, incoming.variance + noise)
    # Swap the stale message for the fresh one in the parent's belief.
    self.parent.update(self.message_to_parent, fresh)
    self.message_to_parent = fresh
def to_child(self):
    """Send an updated message from this factor to its child node.

    Mirror of ``to_parent``: the parent's belief (minus this factor's
    previous upstream message) is widened by ``exp(log_var)`` and
    pushed into the child's belief.
    """
    # Parent belief excluding what this factor previously sent it.
    incoming = self.parent / self.message_to_parent
    noise = self.log_var.exp().unsqueeze(0)
    fresh = Gaussian.from_array(incoming.mean, incoming.variance + noise)
    # Swap the stale message for the fresh one in the child's belief.
    self.child.update(self.message_to_child, fresh)
    self.message_to_child = fresh
def to_child(self):
    """Moment-match the product of the two parent beliefs onto the child.

    The child belief is set to a Gaussian with the exact first two
    moments of the product of the two parents; the stored message is
    that belief divided by what the child sent back.
    """
    p0, p1 = self.parents
    m0, v0 = p0.mean_and_variance
    m1, v1 = p1.mean_and_variance
    # First two moments of a product of independent Gaussian variables.
    prod_mean = m0 * m1
    prod_var = v0 * m1 ** 2 + v1 * m0 ** 2 + v0 * v1
    belief = Gaussian.from_array(prod_mean, prod_var)
    # Message = new belief with the child's own contribution divided out.
    incoming = self.child / self.message_to_child
    fresh = belief / incoming
    self.child.set_to(belief)
    self.message_to_child = fresh
def to_child(self):
    """Send the sum of the parents' incoming messages to the child.

    Each parent belief has this factor's previous message divided out;
    the resulting means and variances are concatenated along dim 1 and
    nan-summed (so NaN entries are ignored) to form the sum's Gaussian.
    """
    message_from_parents = tuple(
        p / mtp for p, mtp in zip(self.parents, self.message_to_parents))
    m = tuple(mfp.mean for mfp in message_from_parents)
    v = tuple(mfp.variance for mfp in message_from_parents)
    # Fix: torch's reduction kwarg is `keepdim`, not `keepdims`
    # (the `to_parent` of the Linear factor already uses `keepdim`).
    mean = torch.cat(m, 1).nansum(1, keepdim=True)
    var = torch.cat(v, 1).nansum(1, keepdim=True)
    message_to_child = Gaussian.from_array(mean, var)
    self.child.update(self.message_to_child, message_to_child)
    self.message_to_child = message_to_child
def to_child(self):
    """Push the affine transform of the parent belief to the child.

    Applies mean -> W·m + b and variance -> W²·v (nan-summing over
    dim 1, so NaN entries are ignored) to form the outgoing message.
    """
    # Parent belief excluding this factor's previous upstream message;
    # a trailing axis is added so it broadcasts against the weights.
    incoming = self.parent.unsqueeze(-1) / self.message_to_parent
    pm, pv = incoming.mean_and_variance
    w = self.weight.unsqueeze(0)
    out_mean = (pm * w).nansum(1) + self.bias.unsqueeze(0)
    out_var = (pv * w ** 2).nansum(1)
    fresh = Gaussian.from_array(out_mean, out_var)
    self.child.update(self.message_to_child, fresh)
    self.message_to_child = fresh
def to_parent(self):
    """Send a leave-one-out message from the sum factor to each parent.

    For parent i the message mean is (child mean - sum of the other
    parents' means) and the variance is (child variance + sum of the
    other parents' variances), computed as total-minus-own.
    """
    message_from_child = self.child / self.message_to_child
    cm, cv = message_from_child.mean_and_variance
    message_from_parents = tuple(
        p / mtp for p, mtp in zip(self.parents, self.message_to_parents))
    pm = tuple(mfp.mean for mfp in message_from_parents)
    pv = tuple(mfp.variance for mfp in message_from_parents)
    # Fix: torch's reduction kwarg is `keepdim`, not `keepdims`
    # (consistent with the Linear factor's `to_parent`).
    mean = torch.cat(pm, 1).nansum(1, keepdim=True)
    var = torch.cat(pv, 1).nansum(1, keepdim=True)
    # Leave-one-out: remove each parent's own contribution from the totals.
    mmtp = tuple(cm - mean + m for m in pm)
    vmtp = tuple(cv + var - v for v in pv)
    mtp = tuple(Gaussian.from_array(m, v) for m, v in zip(mmtp, vmtp))
    for p, mtp_prev, mtp_new in zip(self.parents, self.message_to_parents,
                                    mtp):
        p.update(mtp_prev, mtp_new)
    self.message_to_parents = mtp
def to_parent(self):
    """Send messages from the linear factor back to its parent.

    For each (input, output) pair the affine map is inverted with a
    leave-one-out correction over the input dimension, then the
    per-output messages are combined along the last axis before
    updating the parent belief.
    """
    # Child belief excluding this factor's previous downstream message.
    message_from_child = self.child / self.message_to_child
    cm, cv = message_from_child.mean_and_variance
    # Parent belief excluding this factor's previous upstream message;
    # a trailing axis is added so it broadcasts against the weights.
    message_from_parent = self.parent.unsqueeze(
        -1) / self.message_to_parent
    pm, pv = message_from_parent.mean_and_variance
    w = self.weight.unsqueeze(0)
    b = self.bias.unsqueeze(0).unsqueeze(0)
    # Leave-one-out weighted sums over dim 1: total minus own term.
    pm_sum = (pm * w).nansum(1, keepdim=True) - pm * w
    pv_sum = (pv * w**2).nansum(1, keepdim=True) - pv * w**2
    # Invert the affine map: x = (y - b - contribution of the others) / w;
    # variances divide by w**2 accordingly.
    mean = (cm.unsqueeze(1) - b - pm_sum) / w
    var = (cv.unsqueeze(1) + pv_sum) / (w**2)
    message_to_parent = Gaussian.from_array(mean, var)
    # Collapse per-output messages along the last axis — presumably a
    # Gaussian product; TODO confirm against Gaussian.product's definition.
    message_to_parent_sum = message_to_parent.product(-1)
    # NOTE(review): both the full per-output message and its collapsed
    # form are cached; the parent is updated with the collapsed one.
    self.parent.update(self.message_to_parent_sum, message_to_parent_sum)
    self.message_to_parent = message_to_parent
    self.message_to_parent_sum = message_to_parent_sum
# ----------------------------------------------------------------------------- # Logistic import torch import numpy as np from NNVI.vmp.utils import sigmoid_integrals from NNVI.vmp.bernoulli import Bernoulli from NNVI.vmp.gaussian import Gaussian from NNVI.vmp.factors import Logistic shape = (3, 2) mean = torch.randn(shape) variance = torch.rand(shape) # sigmoid_integrals(mean, variance, [0,1,2]) parent = Gaussian.from_array(mean, variance) child = (torch.rand(shape) > 0.5).float() child[0, 0] = np.nan child = Bernoulli.observed(child) self = Logistic(parent, child) self self.forward() self.backward() self self.to_elbo() # ----------------------------------------------------------------------------- # Linear