Ejemplo n.º 1
0
 def to_parent(self, i):
     """Send parent ``i`` a new Gaussian message and update it.

     The incoming child message is the child's belief divided by the
     message last sent to the child.  Its precision is multiplied by
     ``variance + mean**2`` of the *other* parent (index ``1 - i``), and
     its mean-times-precision by that parent's mean.  Parent ``i`` is then
     updated by swapping its previously stored message for the new one.
     The stored message object is overwritten in place with ``set_to``.
     """
     incoming = self.child / self.message_to_child
     other_mean, other_var = self.parents[1 - i].mean_and_variance
     child_precision, child_mtp = incoming.natural
     # variance + mean**2 is the second moment of the other parent.
     scaled_precision = child_precision * (other_var + other_mean ** 2)
     scaled_mtp = child_mtp * other_mean
     new_message = Gaussian(scaled_precision, scaled_mtp)
     previous_message = self.message_to_parents[i]
     self.parents[i].update(previous_message, new_message)
     previous_message.set_to(new_message)
Ejemplo n.º 2
0
 def to_parent(self):
     """Send the parent the children's messages concatenated along dim 1.

     For each child, the incoming message is the child's belief divided by
     the message last sent to it.  The precision and mean-times-precision
     tensors of those messages are concatenated along dimension 1 into a
     single Gaussian message.  The parent is updated by replacing the
     previously sent message with it, and the new message is stored.
     """
     precisions = []
     mean_times_precisions = []
     for child, sent in zip(self.children, self.message_to_children):
         incoming = child / sent
         precisions.append(incoming.precision)
         mean_times_precisions.append(incoming.mean_times_precision)
     new_message = Gaussian(torch.cat(precisions, 1),
                            torch.cat(mean_times_precisions, 1))
     self.parent.update(self.message_to_parent, new_message)
     self.message_to_parent = new_message
Ejemplo n.º 3
0
 def to_child(self):
     """Split the parent's message into one Gaussian message per child.

     The incoming parent message is the parent's belief divided by the
     message last sent to it.  Its precision and mean-times-precision
     tensors are split along dimension 1 according to ``self._chunks``
     (passed straight to ``Tensor.split``).  Each child is updated by
     replacing its previous message with its chunk, and the new messages
     are stored as a tuple.
     """
     precision, mean_times_precision = (
         self.parent / self.message_to_parent).natural
     precision_chunks = precision.split(self._chunks, 1)
     mtp_chunks = mean_times_precision.split(self._chunks, 1)
     new_messages = tuple(
         Gaussian(chunk_p, chunk_mtp)
         for chunk_p, chunk_mtp in zip(precision_chunks, mtp_chunks))
     for child, old_message, new_message in zip(
             self.children, self.message_to_children, new_messages):
         child.update(old_message, new_message)
     self.message_to_children = new_messages
Ejemplo n.º 4
0
 def to_parent(self):
     """Send the parent a Gaussian message built from sigmoid integrals.

     With the parent's current mean ``m`` and variance ``v``, and
     ``I0, I1 = sigmoid_integrals(m, v, [0, 1])``, the message is

         precision            = sqrt(v) * I1 / v
         mean_times_precision = m * precision + child.proba - I0

     Wherever ``self.child.is_uniform`` is true, the message is replaced by
     precision ``Gaussian.uniform_precision`` and mean-times-precision 0.
     The parent is updated by replacing the previously sent message with
     the new one, which is then stored.
     """
     m, v = self.parent.mean_and_variance
     integrals = sigmoid_integrals(m, v, [0, 1])
     sd = torch.sqrt(v)
     # Equivalent to ((m*I0 + sd*I1) - m*I0) / v.  The m*I0 terms cancel,
     # so they are not added and subtracted: doing so in floating point
     # loses precision when |m*I0| >> |sd*I1| (catastrophic cancellation).
     p = sd * integrals[1] / v
     mtp = m * p + self.child.proba - integrals[0]
     uniform = self.child.is_uniform
     p = torch.where(uniform,
                     torch.full_like(p, Gaussian.uniform_precision), p)
     mtp = torch.where(uniform, torch.zeros_like(mtp), mtp)
     message_to_parent = Gaussian(p, mtp)
     self.parent.update(self.message_to_parent, message_to_parent)
     self.message_to_parent = message_to_parent
Ejemplo n.º 5
0
    def to_parent(self):
        """Send the parent per-row sums of the child's messages.

        The incoming child message is the child's belief divided by the
        message last sent to the child; it supplies ``p`` (precision) and
        ``mtp`` (mean times precision).  For a parent row ``i``, the child
        entries selected by ``self.which[i]`` are summed over dim 0 with
        ``nansum``.  This is only done for a random sample of rows: ``n // 2``
        draws with replacement, deduplicated.  All other rows keep the zero
        fill.

        The summed message updates the parent, replacing
        ``self.message_to_parent_sum``.  Both the un-summed child message
        (``self.message_to_parent``) and the summed one are then stored.
        """
        message_from_child = self.child / self.message_to_child
        p, mtp = message_from_child.natural
        p_sum = torch.zeros(self.parent.shape, device=p.device, dtype=p.dtype)
        mtp_sum = torch.zeros(self.parent.shape,
                              device=p.device,
                              dtype=p.dtype)

        # A deterministic variant runs the loop below over every row
        # (range(n)); sampling a random subset reportedly helps a bit.
        n = self.parent.shape[0]
        # torch.Tensor hashes by identity, so set() over a tensor's 0-d
        # elements never merges duplicate indices; unique() actually does.
        rows = torch.randint(0, n, (n // 2, )).unique().tolist()
        for i in rows:
            p_sum[i, ] = p[self.which[i, ], ].nansum(0)
            mtp_sum[i, ] = mtp[self.which[i, ], ].nansum(0)
        # NOTE(review): rows not sampled are sent a zero (p=0, mtp=0)
        # message, which presumably *removes* their previous contribution
        # via update() rather than leaving it unchanged. Confirm this is
        # intended versus starting from self.message_to_parent_sum.

        message_to_parent = Gaussian(p, mtp)
        message_to_parent_sum = Gaussian(p_sum, mtp_sum)
        self.parent.update(self.message_to_parent_sum, message_to_parent_sum)
        self.message_to_parent = message_to_parent
        self.message_to_parent_sum = message_to_parent_sum
Ejemplo n.º 6
0
import torch
from NNVI.vmp.gaussian import Gaussian

# Smoke-test script exercising the Gaussian API: update, in-place product,
# item assignment, entropies, point masses, device transfer, split and cat.
# Note: `self` is just a module-level variable holding a Gaussian here.

p = torch.ones((2, 3)) * 2.
mtp = torch.ones((2, 3)) * 3.


self = Gaussian(p, mtp)
other = Gaussian(p*0.2, mtp*0.3)
other2 = Gaussian(p*0.5, mtp*0.5)

# Whole-object update, then a single-row update.
self.update(other, other2)

self[1].update(other[1], other2[1])

self *= other

# Overwrite the first row with a fresh 1-D Gaussian.
p = torch.ones((3, )) * 4.
mtp = torch.ones((3, )) * 5.

self[0, :] = Gaussian(p, mtp)
self.entropy()
self.negative_entropy()
Gaussian.point_mass(p)

# Only exercise the CUDA path when a GPU exists; otherwise the script would
# abort here on CPU-only machines.
# NOTE(review): if cuda() moves `self` in place, the cat() below mixes CUDA
# and CPU Gaussians; confirm cuda() returns a copy.
if torch.cuda.is_available():
    self.cuda()

self.split((1, 2), 1)

Gaussian.cat([self, other, other2], 0)