コード例 #1
0
ファイル: test_node.py プロジェクト: buptpriswang/bayespy
    def check_message_to_parent(self, plates_child, plates_message,
                                plates_mask, plates_parent, dims=(2,)):
        """Check ``Node._message_to_parent`` against a brute-force reference.

        A dummy child node with plates ``plates_child`` produces a random
        message with plates ``plates_message`` and an alternating True/False
        mask with plates ``plates_mask``. The message it sends to a dummy
        parent with plates ``plates_parent`` is compared (with
        ``testing.assert_allclose``) to a reference obtained by explicit
        masking and summation with ``np.sum``.

        Parameters
        ----------
        plates_child, plates_message, plates_mask, plates_parent : tuple of int
            Plate shapes of the child node, the raw message, the mask and the
            parent node.
        dims : tuple of int, optional
            Variable dimensions shared by the child and the parent.
        """
        # Dummy message
        msg = np.random.randn(*(plates_message+dims))
        # Mask with every other True and every other False
        mask = np.mod(np.arange(np.prod(plates_mask)).reshape(plates_mask),
                      2) == 0

        # Set up the dummy model
        class Dummy(Node):
            _moments = Moments()
            def __init__(self, *args, **kwargs):
                self._parent_moments = len(args)*(Moments(),)
                super().__init__(*args, **kwargs)
            def _get_message_and_mask_to_parent(self, index):
                return ([msg], mask)
            def _get_id_list(self):
                return []
        parent = Dummy(dims=[dims], plates=plates_parent)
        child = Dummy(parent, dims=[dims], plates=plates_child)

        m = child._message_to_parent(0)[0] * np.ones(plates_parent+dims)

        # Brute-force computation of the message without too much checking.
        # Pad the mask with one trailing unit axis per variable dimension so
        # it lines up with the plate axes of the message for any len(dims).
        ndim = len(dims)
        mask_full = np.reshape(mask, np.shape(mask) + (1,)*ndim)
        m_true = msg * misc.squeeze(mask_full) * np.ones(plates_child+dims)
        # Visit the child plate axes from the outermost inwards: removing an
        # outer axis (keepdims=False) then cannot shift the negative index of
        # any inner axis that is still to be visited.
        for ind in reversed(range(len(plates_child))):
            axis = -ind - 1 - ndim
            if ind >= len(plates_parent):
                # The parent lacks this plate axis: sum it out completely.
                m_true = np.sum(m_true, axis=axis, keepdims=False)
            elif plates_parent[-ind-1] == 1:
                # The parent has a singleton plate here: sum, keep the axis.
                m_true = np.sum(m_true, axis=axis, keepdims=True)

        testing.assert_allclose(m, m_true,
                                err_msg="Incorrect message.")
コード例 #2
0
ファイル: test_node.py プロジェクト: vcsrc/bayespy
    def check_message_to_parent(self, plates_child, plates_message,
                                plates_mask, plates_parent, dims=(2,)):
        """Check ``Node._message_to_parent`` against a brute-force reference.

        A dummy child node with plates ``plates_child`` produces a random
        message with plates ``plates_message`` and an alternating True/False
        mask with plates ``plates_mask``. The message it sends to a dummy
        parent with plates ``plates_parent`` is compared (with
        ``testing.assert_allclose``) to a reference obtained by explicit
        masking and summation with ``np.sum``.

        Parameters
        ----------
        plates_child, plates_message, plates_mask, plates_parent : tuple of int
            Plate shapes of the child node, the raw message, the mask and the
            parent node.
        dims : tuple of int, optional
            Variable dimensions shared by the child and the parent.
        """
        # Dummy message
        msg = np.random.randn(*(plates_message+dims))
        # Mask with every other True and every other False
        mask = np.mod(np.arange(np.prod(plates_mask)).reshape(plates_mask),
                      2) == 0

        # Set up the dummy model
        class Dummy(Node):
            _moments = Moments()
            def __init__(self, *args, **kwargs):
                self._parent_moments = len(args)*(Moments(),)
                super().__init__(*args, **kwargs)
            def _get_message_and_mask_to_parent(self, index):
                return ([msg], mask)
            def _get_id_list(self):
                return []
        parent = Dummy(dims=[dims], plates=plates_parent)
        child = Dummy(parent, dims=[dims], plates=plates_child)

        m = child._message_to_parent(0)[0] * np.ones(plates_parent+dims)

        # Brute-force computation of the message without too much checking.
        # Pad the mask with one trailing unit axis per variable dimension so
        # it lines up with the plate axes of the message for any len(dims).
        ndim = len(dims)
        mask_full = np.reshape(mask, np.shape(mask) + (1,)*ndim)
        m_true = msg * misc.squeeze(mask_full) * np.ones(plates_child+dims)
        # Visit the child plate axes from the outermost inwards: removing an
        # outer axis (keepdims=False) then cannot shift the negative index of
        # any inner axis that is still to be visited.
        for ind in reversed(range(len(plates_child))):
            axis = -ind - 1 - ndim
            if ind >= len(plates_parent):
                # The parent lacks this plate axis: sum it out completely.
                m_true = np.sum(m_true, axis=axis, keepdims=False)
            elif plates_parent[-ind-1] == 1:
                # The parent has a singleton plate here: sum, keep the axis.
                m_true = np.sum(m_true, axis=axis, keepdims=True)

        testing.assert_allclose(m, m_true,
                                err_msg="Incorrect message.")
コード例 #3
0
ファイル: node.py プロジェクト: chagge/bayespy
    def _message_to_parent(self, index):
        """Compute the message this node sends to its ``index``-th parent.

        The raw message and mask are obtained from
        ``self._get_message_and_mask_to_parent(index)``. Each message entry
        is masked, summed over the plate axes the parent does not have (or
        has as singletons) and scaled by the plate-multiplier factor that
        ``broadcasting_multiplier`` computes for this node and the parent.

        Parameters
        ----------
        index : int
            Index of the target parent in ``self.parents``.

        Returns
        -------
        list or callable
            The raw message list, with every non-``None`` entry replaced in
            place by its reduced form. If the raw message is a callable (a
            log-pdf for black-box variational inference), a callable is
            returned instead. It evaluates the log-pdf and applies the same
            mask / plate reduction to the result.

        Raises
        ------
        ValueError
            If ``index >= len(self.parents)``, or if the plate multipliers of
            this node and the parent are incompatible.
        """
        # Compute the message, check plates, apply mask and sum over some plates
        if index >= len(self.parents):
            raise ValueError("Parent index larger than the number of parents")

        # Compute the message and mask
        (m, mask) = self._get_message_and_mask_to_parent(index)
        mask = misc.squeeze(mask)

        # Plates in the mask
        plates_mask = np.shape(mask)

        # The parent we're sending the message to
        parent = self.parents[index]

        # Plates with respect to the parent
        plates_self = self._plates_to_parent(index)

        # Plate multiplier of the parent
        multiplier_parent = self._plates_multiplier_from_parent(index)

        # Check if m is a logpdf function (for black-box variational inference)
        if callable(m):
            logpdf = m

            def m_function(*args):
                lpdf = logpdf(*args)
                # Log pdf only contains plate axes!
                plates_m = np.shape(lpdf)
                r = (self.broadcasting_multiplier(plates_self,
                                                  plates_m,
                                                  plates_mask,
                                                  parent.plates) *
                     self.broadcasting_multiplier(self.plates_multiplier,
                                                  multiplier_parent))
                # Mask the log pdf and sum over the plate axes that the
                # parent does not have (or has as singletons). The axes are
                # computed on the broadcast of the log pdf and the mask,
                # because that is the shape sum_multiply works on.
                plates_msg = misc.broadcasted_shape(plates_m, plates_mask)
                axes_msg = misc.axes_to_collapse(plates_msg, parent.plates)
                lpdf = misc.sum_multiply(mask, lpdf, r,
                                         axis=axes_msg,
                                         keepdims=True)

                # Remove leading singular plates if the parent does not have
                # those plate axes.
                return misc.squeeze_to_dim(lpdf, len(parent.plates))

            return m_function

        # Compact the message to a proper shape
        for i, m_i in enumerate(m):

            # Empty messages are given as None. We can ignore those.
            if m_i is None:
                continue

            try:
                r = self.broadcasting_multiplier(self.plates_multiplier,
                                                 multiplier_parent)
            except Exception as err:
                # Narrower than a bare ``except:`` so KeyboardInterrupt and
                # SystemExit are not turned into a ValueError; keep the cause.
                raise ValueError("The plate multipliers are incompatible. "
                                 "This node (%s) has %s and parent[%d] "
                                 "(%s) has %s"
                                 % (self.name,
                                    self.plates_multiplier,
                                    index,
                                    parent.name,
                                    multiplier_parent)) from err

            ndim = len(parent.dims[i])
            # Source and target shapes: this node's plates (w.r.t. the
            # parent) followed by the broadcast variable dimensions.
            if ndim > 0:
                dims = misc.broadcasted_shape(np.shape(m_i)[-ndim:],
                                              parent.dims[i])
                from_shape = plates_self + dims
            else:
                from_shape = plates_self
            to_shape = parent.get_shape(i)
            # Add variable axes to the mask
            mask_i = misc.add_trailing_axes(mask, ndim)
            # Apply mask and sum plate axes as necessary (and apply plate
            # multiplier)
            m[i] = r * misc.sum_multiply_to_plates(m_i, mask_i,
                                                   to_plates=to_shape,
                                                   from_plates=from_shape,
                                                   ndim=0)

        return m
コード例 #4
0
ファイル: node.py プロジェクト: buptpriswang/bayespy
    def _message_to_parent(self, index):
        """Return the message from this node to its ``index``-th parent.

        For every non-``None`` entry of the raw message returned by
        ``self._get_message_and_mask_to_parent(index)``, the (squeezed) mask
        is applied. The entry is then summed over plate axes the parent does
        not have (or has as singletons) and multiplied by the factor from
        ``self._plate_multiplier``. Entries are replaced in the raw message
        list, which is returned.

        Raises
        ------
        ValueError
            If ``index >= len(self.parents)``, or if the plates of a message,
            the mask and the parent are not compatible with this node's plates
            (as reported by ``self._plate_multiplier``).
        """
        if index >= len(self.parents):
            raise ValueError("Parent index larger than the number of parents")

        # Raw per-moment messages and the mask (leading unit axes removed)
        messages, mask = self._get_message_and_mask_to_parent(index)
        mask = misc.squeeze(mask)
        mask_plates = np.shape(mask)

        target = self.parents[index]

        for k in range(len(messages)):
            msg = messages[k]
            # None marks an empty message; leave it as it is.
            if msg is None:
                continue

            # Split the message shape into plate axes and variable axes
            msg_shape = np.shape(msg)
            n_dims = len(target.dims[k])
            msg_plates = msg_shape[:-n_dims] if n_dims > 0 else msg_shape

            # Plates that are non-unit for this node but unit in the message,
            # the mask and the parent would not be summed by broadcasting, so
            # their size is folded into a scalar multiplier instead.
            own_plates = self._plates_to_parent(index)
            try:
                multiplier = self._plate_multiplier(own_plates,
                                                    msg_plates,
                                                    mask_plates,
                                                    target.plates)
            except ValueError:
                raise ValueError("The plates of the message, the mask and "
                                 "parent[%d] node (%s) are not a "
                                 "broadcastable subset of the plates of "
                                 "this node (%s).  The message has shape "
                                 "%s, meaning plates %s. The mask has "
                                 "plates %s. This node has plates %s with "
                                 "respect to the parent[%d], which has "
                                 "plates %s."
                                 % (index,
                                    target.name,
                                    self.name,
                                    msg_shape,
                                    msg_plates,
                                    mask_plates,
                                    own_plates,
                                    index,
                                    target.plates))

            # Give the mask unit axes for the variable dimensions
            padded_mask_shape = mask_plates + (1,) * n_dims
            mask_k = np.reshape(mask, padded_mask_shape)

            # Pre-sum the mask over axes found in neither message nor parent
            parent_shape = target.get_shape(k)
            full_shape = misc.broadcasted_shape(msg_shape, parent_shape)
            mask_k = np.sum(mask_k,
                            axis=misc.axes_to_collapse(padded_mask_shape,
                                                       full_shape),
                            keepdims=True)

            # Masked, scaled message summed down to the parent's axes, then
            # stripped of leading singleton axes the parent does not have.
            reduced = misc.sum_multiply(mask_k, msg, multiplier,
                                        axis=misc.axes_to_collapse(full_shape,
                                                                   parent_shape),
                                        keepdims=True)
            messages[k] = misc.squeeze_to_dim(reduced, len(parent_shape))

        return messages
コード例 #5
0
ファイル: node.py プロジェクト: pfjob09/bayespy
    def _message_to_parent(self, index):
        """Compute the message this node sends to its ``index``-th parent.

        The raw message and mask are obtained from
        ``self._get_message_and_mask_to_parent(index)``. Each message entry
        is masked, summed over the plate axes the parent does not have (or
        has as singletons) and scaled by the plate-multiplier factor that
        ``broadcasting_multiplier`` computes for this node and the parent.

        Parameters
        ----------
        index : int
            Index of the target parent in ``self.parents``.

        Returns
        -------
        list or callable
            The raw message list, with every non-``None`` entry replaced in
            place by its reduced form. If the raw message is a callable (a
            log-pdf for black-box variational inference), a callable is
            returned instead. It evaluates the log-pdf and applies the same
            mask / plate reduction to the result.

        Raises
        ------
        ValueError
            If ``index >= len(self.parents)``, or if the plate multipliers of
            this node and the parent are incompatible.
        """
        # Compute the message, check plates, apply mask and sum over some plates
        if index >= len(self.parents):
            raise ValueError("Parent index larger than the number of parents")

        # Compute the message and mask
        (m, mask) = self._get_message_and_mask_to_parent(index)
        mask = misc.squeeze(mask)

        # Plates in the mask
        plates_mask = np.shape(mask)

        # The parent we're sending the message to
        parent = self.parents[index]

        # Plates with respect to the parent
        plates_self = self._plates_to_parent(index)

        # Plate multiplier of the parent
        multiplier_parent = self._plates_multiplier_from_parent(index)

        # Check if m is a logpdf function (for black-box variational inference)
        if callable(m):
            logpdf = m

            def m_function(*args):
                lpdf = logpdf(*args)
                # Log pdf only contains plate axes!
                plates_m = np.shape(lpdf)
                r = (self.broadcasting_multiplier(plates_self, plates_m,
                                                  plates_mask, parent.plates) *
                     self.broadcasting_multiplier(self.plates_multiplier,
                                                  multiplier_parent))
                # Mask the log pdf and sum over the plate axes that the
                # parent does not have (or has as singletons). The axes are
                # computed on the broadcast of the log pdf and the mask,
                # because that is the shape sum_multiply works on.
                plates_msg = misc.broadcasted_shape(plates_m, plates_mask)
                axes_msg = misc.axes_to_collapse(plates_msg, parent.plates)
                lpdf = misc.sum_multiply(mask,
                                         lpdf,
                                         r,
                                         axis=axes_msg,
                                         keepdims=True)

                # Remove leading singular plates if the parent does not have
                # those plate axes.
                return misc.squeeze_to_dim(lpdf, len(parent.plates))

            return m_function

        # Compact the message to a proper shape
        for i, m_i in enumerate(m):

            # Empty messages are given as None. We can ignore those.
            if m_i is None:
                continue

            try:
                r = self.broadcasting_multiplier(self.plates_multiplier,
                                                 multiplier_parent)
            except Exception as err:
                # Narrower than a bare ``except:`` so KeyboardInterrupt and
                # SystemExit are not turned into a ValueError; keep the cause.
                raise ValueError("The plate multipliers are incompatible. "
                                 "This node (%s) has %s and parent[%d] "
                                 "(%s) has %s" %
                                 (self.name, self.plates_multiplier, index,
                                  parent.name, multiplier_parent)) from err

            ndim = len(parent.dims[i])
            # Source and target shapes: this node's plates (w.r.t. the
            # parent) followed by the broadcast variable dimensions.
            if ndim > 0:
                dims = misc.broadcasted_shape(
                    np.shape(m_i)[-ndim:], parent.dims[i])
                from_shape = plates_self + dims
            else:
                from_shape = plates_self
            to_shape = parent.get_shape(i)
            # Add variable axes to the mask
            mask_i = misc.add_trailing_axes(mask, ndim)
            # Apply mask and sum plate axes as necessary (and apply plate
            # multiplier)
            m[i] = r * misc.sum_multiply_to_plates(m_i,
                                                   mask_i,
                                                   to_plates=to_shape,
                                                   from_plates=from_shape,
                                                   ndim=0)

        return m