Example #1
0
    def _message_to_parent(self, index):
        """
        Compute the message this node sends to the parent ``self.parents[index]``.

        The raw message and mask are obtained from
        ``self._get_message_and_mask_to_parent(index)``. Each non-``None``
        message component is masked and summed from this node's plates to the
        parent's shape. It is then scaled by the plate-multiplier factor.

        Parameters
        ----------
        index : int
            Position of the target parent in ``self.parents``.

        Returns
        -------
        sequence or callable
            The message sequence returned by
            ``_get_message_and_mask_to_parent``, with its non-``None`` entries
            replaced in place by the compacted arrays. If that message is
            callable (a log-pdf, used for black-box variational inference), a
            wrapper function is returned instead. Calling the wrapper
            evaluates the log-pdf, masks and sums it down to the parent's
            plates, and returns the result.

        Raises
        ------
        ValueError
            If ``index`` is not smaller than the number of parents, or if the
            plate multipliers of this node and the parent are incompatible.
        """
        # Compute the message, check plates, apply mask and sum over some plates
        if index >= len(self.parents):
            raise ValueError("Parent index larger than the number of parents")

        # Compute the message and mask
        (m, mask) = self._get_message_and_mask_to_parent(index)
        mask = misc.squeeze(mask)

        # Plates in the mask
        plates_mask = np.shape(mask)

        # The parent we're sending the message to
        parent = self.parents[index]

        # Plates with respect to the parent
        plates_self = self._plates_to_parent(index)

        # Plate multiplier of the parent
        multiplier_parent = self._plates_multiplier_from_parent(index)

        # The plate-multiplier factor does not depend on the message
        # component, so validate and compute it once instead of per component.
        try:
            r_plates = self.broadcasting_multiplier(self.plates_multiplier,
                                                    multiplier_parent)
        except Exception as err:
            raise ValueError("The plate multipliers are incompatible. "
                             "This node (%s) has %s and parent[%d] "
                             "(%s) has %s"
                             % (self.name,
                                self.plates_multiplier,
                                index,
                                parent.name,
                                multiplier_parent)) from err

        # Check if m is a logpdf function (for black-box variational inference)
        if callable(m):

            def m_function(*args):
                lpdf = m(*args)
                # Log pdf only contains plate axes, so the (squeezed) mask
                # needs no trailing variable axes here.
                plates_m = np.shape(lpdf)
                r = (self.broadcasting_multiplier(plates_self,
                                                  plates_m,
                                                  plates_mask,
                                                  parent.plates) *
                     r_plates)
                # Collapse every plate axis of the masked log pdf that the
                # parent does not have (or has as a unit axis). Use the
                # broadcast of log pdf and mask so mask-only plates are
                # summed too.
                plates_msg = misc.broadcasted_shape(plates_m, plates_mask)
                axes_msg = misc.axes_to_collapse(plates_msg, parent.plates)
                msg = misc.sum_multiply(mask, lpdf, r,
                                        axis=axes_msg,
                                        keepdims=True)

                # Remove leading singular plates if the parent does not have
                # those plate axes.
                return misc.squeeze_to_dim(msg, len(parent.plates))

            return m_function

        # Compact the message to a proper shape
        for i, m_i in enumerate(m):

            # Empty messages are given as None. We can ignore those.
            if m_i is None:
                continue

            ndim = len(parent.dims[i])
            # Source and target shapes
            if ndim > 0:
                dims = misc.broadcasted_shape(np.shape(m_i)[-ndim:],
                                              parent.dims[i])
                from_shape = plates_self + dims
            else:
                from_shape = plates_self
            to_shape = parent.get_shape(i)
            # Add variable axes to the mask
            mask_i = misc.add_trailing_axes(mask, ndim)
            # Apply mask and sum plate axes as necessary (and apply plate
            # multiplier)
            m[i] = r_plates * misc.sum_multiply_to_plates(m_i, mask_i,
                                                          to_plates=to_shape,
                                                          from_plates=from_shape,
                                                          ndim=0)

        return m
Example #2
0
    def _message_to_parent(self, index):
        """
        Compute the message this node sends to the parent ``self.parents[index]``.

        The raw message and mask are obtained from
        ``self._get_message_and_mask_to_parent(index)``. Each non-``None``
        message component is masked and summed from this node's plates to the
        parent's shape. It is then scaled by the plate-multiplier factor.

        Parameters
        ----------
        index : int
            Position of the target parent in ``self.parents``.

        Returns
        -------
        sequence or callable
            The message sequence returned by
            ``_get_message_and_mask_to_parent``, with its non-``None`` entries
            replaced in place by the compacted arrays. If that message is
            callable (a log-pdf, used for black-box variational inference), a
            wrapper function is returned instead. Calling the wrapper
            evaluates the log-pdf, masks and sums it down to the parent's
            plates, and returns the result.

        Raises
        ------
        ValueError
            If ``index`` is not smaller than the number of parents, or if the
            plate multipliers of this node and the parent are incompatible.
        """
        # Compute the message, check plates, apply mask and sum over some plates
        if index >= len(self.parents):
            raise ValueError("Parent index larger than the number of parents")

        # Compute the message and mask
        (m, mask) = self._get_message_and_mask_to_parent(index)
        mask = misc.squeeze(mask)

        # Plates in the mask
        plates_mask = np.shape(mask)

        # The parent we're sending the message to
        parent = self.parents[index]

        # Plates with respect to the parent
        plates_self = self._plates_to_parent(index)

        # Plate multiplier of the parent
        multiplier_parent = self._plates_multiplier_from_parent(index)

        # The plate-multiplier factor does not depend on the message
        # component, so validate and compute it once instead of per component.
        try:
            r_plates = self.broadcasting_multiplier(self.plates_multiplier,
                                                    multiplier_parent)
        except Exception as err:
            raise ValueError("The plate multipliers are incompatible. "
                             "This node (%s) has %s and parent[%d] "
                             "(%s) has %s" %
                             (self.name, self.plates_multiplier, index,
                              parent.name, multiplier_parent)) from err

        # Check if m is a logpdf function (for black-box variational inference)
        if callable(m):

            def m_function(*args):
                lpdf = m(*args)
                # Log pdf only contains plate axes, so the (squeezed) mask
                # needs no trailing variable axes here.
                plates_m = np.shape(lpdf)
                r = (self.broadcasting_multiplier(plates_self, plates_m,
                                                  plates_mask, parent.plates) *
                     r_plates)
                # Collapse every plate axis of the masked log pdf that the
                # parent does not have (or has as a unit axis). Use the
                # broadcast of log pdf and mask so mask-only plates are
                # summed too.
                plates_msg = misc.broadcasted_shape(plates_m, plates_mask)
                axes_msg = misc.axes_to_collapse(plates_msg, parent.plates)
                msg = misc.sum_multiply(mask,
                                        lpdf,
                                        r,
                                        axis=axes_msg,
                                        keepdims=True)

                # Remove leading singular plates if the parent does not have
                # those plate axes.
                return misc.squeeze_to_dim(msg, len(parent.plates))

            return m_function

        # Compact the message to a proper shape
        for i, m_i in enumerate(m):

            # Empty messages are given as None. We can ignore those.
            if m_i is None:
                continue

            ndim = len(parent.dims[i])
            # Source and target shapes
            if ndim > 0:
                dims = misc.broadcasted_shape(
                    np.shape(m_i)[-ndim:], parent.dims[i])
                from_shape = plates_self + dims
            else:
                from_shape = plates_self
            to_shape = parent.get_shape(i)
            # Add variable axes to the mask
            mask_i = misc.add_trailing_axes(mask, ndim)
            # Apply mask and sum plate axes as necessary (and apply plate
            # multiplier)
            m[i] = r_plates * misc.sum_multiply_to_plates(m_i,
                                                          mask_i,
                                                          to_plates=to_shape,
                                                          from_plates=from_shape,
                                                          ndim=0)

        return m