Ejemplo n.º 1
0
    def get_message(self, index, u_parents):
        """
        Compute the message and the mask to the parent ``index``.

        The message from the children is zeroed wherever the mask is False.
        It is then multiplied with the moments of every other parent and
        summed over the axes that the parent does not have. The summation is
        done here, so the result is divided by the total size of the summed
        axes. This cancels the plate multiplier applied in
        ``message_to_parent``.

        Parameters
        ----------
        index : int
            Index of the parent in ``self.parents``.
        u_parents : sequence
            Moments of the parents; ``u_parents[k][i]`` is moment ``i`` of
            parent ``k``.

        Returns
        -------
        tuple
            ``(m, mask)``: ``m`` holds the two messages (moment 0 and
            moment 1). ``mask`` is the mask OR-reduced onto the parent's
            plates.
        """
        (m, mask) = self.message_from_children()
        parent = self.parents[index]

        # Positions where the mask is False; the same for both moments.
        inactive = np.logical_not(mask)

        for i in range(2):
            # Zero the masked-out elements (in place).
            np.copyto(m[i], 0, where=inactive)

            # Append i+1 trailing unit axes to the message so it broadcasts
            # against the parents' moments (presumably one axis per variable
            # dimension of moment i).
            m[i] = utils.add_trailing_axes(m[i], i + 1)

            # Message times the i-th moment of every other parent.
            factors = [m[i]] + [u[i] for (k, u) in enumerate(u_parents)
                                if k != index]

            # Axes of the broadcasted product that the parent does not have.
            full_shape = utils.broadcasted_shape_from_arrays(*factors)
            axes = utils.axes_to_collapse(full_shape, parent.get_shape(i))

            # The sum over `axes` happens here, so divide by the total size
            # of the summed axes.
            multiplier = 1
            for axis in axes:
                multiplier *= full_shape[axis]

            m[i] = utils.sum_product(*factors,
                                     axes_to_sum=axes,
                                     keepdims=True) / multiplier

        # OR-reduce the mask onto the parent's plates.
        collapse = utils.axes_to_collapse(np.shape(mask), parent.plates)
        mask = utils.squeeze_to_dim(np.any(mask, axis=collapse, keepdims=True),
                                    len(parent.plates))

        return (m, mask)
Ejemplo n.º 2
0
    def get_message(self, index, u_parents):
        """
        Compute the message and the mask to the parent ``index``.

        The message from the children is zeroed wherever the mask is False.
        It is then multiplied with the moments of every other parent and
        summed over the axes that the parent does not have. The summation is
        done here, so the result is divided by the total size of the summed
        axes. This cancels the plate multiplier applied in
        ``message_to_parent``.

        Parameters
        ----------
        index : int
            Index of the parent in ``self.parents``.
        u_parents : sequence
            Moments of the parents; ``u_parents[k][i]`` is moment ``i`` of
            parent ``k``.

        Returns
        -------
        tuple
            ``(m, mask)``: ``m`` holds the two messages (moment 0 and
            moment 1). ``mask`` is the mask OR-reduced onto the parent's
            plates.
        """
        (m, mask) = self.message_from_children()
        parent = self.parents[index]

        # Positions where the mask is False; the same for both moments.
        inactive = np.logical_not(mask)

        for i in range(2):
            # Zero the masked-out elements (in place).
            np.copyto(m[i], 0, where=inactive)

            # Append i+1 trailing unit axes to the message so it broadcasts
            # against the parents' moments (presumably one axis per variable
            # dimension of moment i).
            m[i] = utils.add_trailing_axes(m[i], i + 1)

            # Message times the i-th moment of every other parent.
            factors = [m[i]] + [u[i] for (k, u) in enumerate(u_parents)
                                if k != index]

            # Axes of the broadcasted product that the parent does not have.
            full_shape = utils.broadcasted_shape_from_arrays(*factors)
            axes = utils.axes_to_collapse(full_shape, parent.get_shape(i))

            # The sum over `axes` happens here, so divide by the total size
            # of the summed axes.
            multiplier = 1
            for axis in axes:
                multiplier *= full_shape[axis]

            m[i] = utils.sum_product(*factors,
                                     axes_to_sum=axes,
                                     keepdims=True) / multiplier

        # OR-reduce the mask onto the parent's plates.
        collapse = utils.axes_to_collapse(np.shape(mask), parent.plates)
        mask = utils.squeeze_to_dim(np.any(mask, axis=collapse, keepdims=True),
                                    len(parent.plates))

        return (m, mask)
Ejemplo n.º 3
0
    def OLD_get_message(self, index, u_parents):
        """
        Compute the message and the mask to the parent ``index`` (older variant).

        In this variant, the mask is included as one factor of the product
        instead of being used to zero the message beforehand. The message
        from the children, the mask and the moments of every other parent
        are multiplied and summed over the axes the parent does not have.
        The result is divided by the total size of the summed axes. This
        cancels the plate multiplier applied in ``message_to_parent``.

        Parameters
        ----------
        index : int
            Index of the parent in ``self.parents``.
        u_parents : sequence
            Moments of the parents; ``u_parents[k][i]`` is moment ``i`` of
            parent ``k``.

        Returns
        -------
        tuple
            ``(m, mask)``: the two messages and the mask OR-reduced onto the
            parent's plates.
        """
        (m, mask) = self.message_from_children()

        parent = self.parents[index]

        # Compute both messages
        for i in range(2):

            # Append i+1 trailing unit axes to both the message and the mask
            # so that they broadcast against the parents' moments.
            mask_i = mask
            for _ in range(i + 1):
                m[i] = np.expand_dims(m[i], axis=-1)
                mask_i = np.expand_dims(mask_i, axis=-1)

            # Elements to multiply together. Multiplying by the boolean mask
            # zeroes the contribution of masked-out elements.
            A = [m[i], mask_i]
            for k in range(len(u_parents)):
                if k != index:
                    A.append(u_parents[k][i])

            # Find the axes that are summed over. The summation happens here
            # (for efficiency), so divide by their total size to cancel the
            # plate multiplier applied in message_to_parent.
            full_shape = utils.broadcasted_shape_from_arrays(*A)
            axes = utils.axes_to_collapse(full_shape, parent.get_shape(i))
            r = 1
            for j in axes:
                r *= full_shape[j]

            # Compute dot product
            m[i] = utils.sum_product(*A, axes_to_sum=axes, keepdims=True) / r

        # OR-reduce the mask onto the parent's plates.
        s = utils.axes_to_collapse(np.shape(mask), parent.plates)
        mask = np.any(mask, axis=s, keepdims=True)
        mask = utils.squeeze_to_dim(mask, len(parent.plates))

        return (m, mask)
Ejemplo n.º 4
0
    def OLD_get_message(self, index, u_parents):
        """
        Compute the message and the mask to the parent ``index`` (older variant).

        In this variant, the mask is included as one factor of the product
        instead of being used to zero the message beforehand. The message
        from the children, the mask and the moments of every other parent
        are multiplied and summed over the axes the parent does not have.
        The result is divided by the total size of the summed axes. This
        cancels the plate multiplier applied in ``message_to_parent``.

        Parameters
        ----------
        index : int
            Index of the parent in ``self.parents``.
        u_parents : sequence
            Moments of the parents; ``u_parents[k][i]`` is moment ``i`` of
            parent ``k``.

        Returns
        -------
        tuple
            ``(m, mask)``: the two messages and the mask OR-reduced onto the
            parent's plates.
        """
        (m, mask) = self.message_from_children()

        parent = self.parents[index]

        # Compute both messages
        for i in range(2):

            # Append i+1 trailing unit axes to both the message and the mask
            # so that they broadcast against the parents' moments.
            mask_i = mask
            for _ in range(i + 1):
                m[i] = np.expand_dims(m[i], axis=-1)
                mask_i = np.expand_dims(mask_i, axis=-1)

            # Elements to multiply together. Multiplying by the boolean mask
            # zeroes the contribution of masked-out elements.
            A = [m[i], mask_i]
            for k in range(len(u_parents)):
                if k != index:
                    A.append(u_parents[k][i])

            # Find the axes that are summed over. The summation happens here
            # (for efficiency), so divide by their total size to cancel the
            # plate multiplier applied in message_to_parent.
            full_shape = utils.broadcasted_shape_from_arrays(*A)
            axes = utils.axes_to_collapse(full_shape, parent.get_shape(i))
            r = 1
            for j in axes:
                r *= full_shape[j]

            # Compute dot product
            m[i] = utils.sum_product(*A, axes_to_sum=axes, keepdims=True) / r

        # OR-reduce the mask onto the parent's plates.
        s = utils.axes_to_collapse(np.shape(mask), parent.plates)
        mask = np.any(mask, axis=s, keepdims=True)
        mask = utils.squeeze_to_dim(mask, len(parent.plates))

        return (m, mask)
Ejemplo n.º 5
0
def _mask_sum(plates_parent, mask):
    """
    OR-reduce ``mask`` onto the plate shape ``plates_parent``.

    The mask is reduced with ``np.any`` over the axes that
    ``utils.axes_to_collapse`` reports as needing collapse to reach
    ``plates_parent``. This presumably means extra leading axes and axes
    where the parent has unit length. The result is then squeezed to
    ``len(plates_parent)`` dimensions.
    """
    collapsed = np.any(mask,
                       axis=utils.axes_to_collapse(np.shape(mask),
                                                   plates_parent),
                       keepdims=True)
    return utils.squeeze_to_dim(collapsed, len(plates_parent))
Ejemplo n.º 6
0
def _message_sum_multiply(plates_parent, dims_parent, *arrays):
    """
    Compute the message to a parent and sum over plates.

    The ``arrays`` are multiplied together with broadcasting. The product is
    summed over every axis that must be collapsed to reach the parent shape
    ``plates_parent + dims_parent``. The result is divided by the sizes of
    the summed *plate* axes. This cancels the plate multiplier that is
    applied later in ``message_to_parent``.

    Parameters
    ----------
    plates_parent : tuple of int
        Plate shape of the parent.
    dims_parent : tuple of int
        Variable dimensions of the parent for this moment. These are taken to
        be the trailing axes of the broadcasted product.
    *arrays : array_like
        Factors of the message; they must broadcast together.

    Returns
    -------
    ndarray
        The summed message, squeezed to ``len(plates_parent + dims_parent)``
        dimensions.
    """
    # The shape of the full message
    shapes = [np.shape(array) for array in arrays]
    shape_full = utils.broadcasted_shape(*shapes)
    # Find axes that should be summed
    shape_parent = plates_parent + dims_parent
    sum_axes = utils.axes_to_collapse(shape_full, shape_parent)
    # Compute the multiplier for cancelling the plate multiplier. Because we
    # are summing over the dimensions already in this function (for
    # efficiency), we need to cancel the effect of the plate multiplier
    # applied in the message_to_parent function. Only plate axes count.
    # An axis index may be non-negative or negative, so normalize it to a
    # negative index: it is a plate axis iff it lies before the trailing
    # len(dims_parent) variable axes. (Comparing a non-negative index with
    # len(plates_parent) would miss extra leading plate axes of shape_full.)
    ndim_full = len(shape_full)
    ndim_dims = len(dims_parent)
    r = 1
    for j in sum_axes:
        j_neg = j - ndim_full if j >= 0 else j
        if j_neg < -ndim_dims:
            r *= shape_full[j]
    # Compute the sum-product
    m = utils.sum_multiply(*arrays, axis=sum_axes, sumaxis=True,
                           keepdims=True) / r
    # Remove extra axes
    m = utils.squeeze_to_dim(m, len(shape_parent))
    return m
Ejemplo n.º 7
0
    def _compute_message_and_mask_to_parent(self, index, m, *u_parents):
        """
        Compute the message and the mask to the parent ``index``.

        To avoid building huge message arrays, this node sums over some axes
        already here, via ``_message_sum_multiply``, rather than leaving all
        of the summation to the generic message passing.

        NOTE(review): the original comment said the mask must be applied
        because of this early summation. However, the message values are
        never multiplied by ``self.mask`` here. Confirm that the incoming
        messages ``m`` are already masked.

        Parameters
        ----------
        index : int
            Index of the parent in ``self.parents``.
        m : list
            The two messages from the children; modified in place.
        *u_parents
            Moments of the parents; ``u_parents[k][i]`` is moment ``i`` of
            parent ``k``.

        Returns
        -------
        tuple
            ``(m, mask)``: the summed messages and ``self.mask`` OR-reduced
            onto the parent's plates.
        """
        mask = self.mask
        parent = self.parents[index]

        for i in (0, 1):
            # Append i+1 trailing unit axes to the message from the children.
            m[i] = utils.add_trailing_axes(m[i], i + 1)

            # Moment i of every parent except the recipient.
            others = [u[i] for (k, u) in enumerate(u_parents) if k != index]

            # Sum over some axes already here to avoid huge message matrices.
            m[i] = _message_sum_multiply(parent.plates,
                                         parent.dims[i],
                                         m[i],
                                         *others)

        # OR-reduce the mask onto the parent's plates.
        collapse = utils.axes_to_collapse(np.shape(mask), parent.plates)
        mask = utils.squeeze_to_dim(np.any(mask, axis=collapse, keepdims=True),
                                    len(parent.plates))

        return (m, mask)
Ejemplo n.º 8
0
    def _mask_to_parent(self, index):
        """
        Get the mask with respect to parent[index].

        The mask tells which plate connections are active. The mask is "summed"
        (logical or) and reshaped into the plate shape of the parent. Thus, it
        can't be used for masking messages, because some plates have been summed
        already. This method is used for propagating the mask to parents.

        Parameters
        ----------
        index : int
            Index of the parent in ``self.parents``.

        Returns
        -------
        ndarray
            The mask OR-reduced onto the parent's plates.

        Raises
        ------
        ValueError
            If the mask from ``_compute_mask_to_parent`` is not a sub-shape of
            this node's plates with respect to the parent.
        """
        mask = self._compute_mask_to_parent(index, self.mask)

        # Check the shape of the mask
        plates_to_parent = self._plates_to_parent(index)
        if not utils.is_shape_subset(np.shape(mask), plates_to_parent):
            raise ValueError("In node %s, the mask being sent to "
                             "parent[%d] (%s) has invalid shape: The shape of "
                             "the mask %s is not a sub-shape of the plates of "
                             "the node with respect to the parent %s. It could "
                             "be that this node (%s) is manipulating plates "
                             "but has not overridden the method "
                             "_compute_mask_to_parent."
                             % (self.name,
                                index,
                                self.parents[index].name,
                                np.shape(mask),
                                plates_to_parent,
                                self.__class__.__name__))

        # "Sum" (i.e., logical or) over the plates that have unit length in
        # the parent node.
        parent_plates = self.parents[index].plates
        s = utils.axes_to_collapse(np.shape(mask), parent_plates)
        mask = np.any(mask, axis=s, keepdims=True)
        mask = utils.squeeze_to_dim(mask, len(parent_plates))
        return mask
Ejemplo n.º 9
0
def message_sum_multiply(plates_parent, dims_parent, *arrays):
    """
    Compute message to parent and sum over plates.

    Divide by the plate multiplier.

    The ``arrays`` are multiplied together with broadcasting. The product is
    summed over every axis that must be collapsed to reach the parent shape
    ``plates_parent + dims_parent``. The result is divided by the sizes of
    the summed *plate* axes. This cancels the plate multiplier that is
    applied later in ``message_to_parent``.

    Parameters
    ----------
    plates_parent : tuple of int
        Plate shape of the parent.
    dims_parent : tuple of int
        Variable dimensions of the parent for this moment. These are taken to
        be the trailing axes of the broadcasted product.
    *arrays : array_like
        Factors of the message; they must broadcast together.

    Returns
    -------
    ndarray
        The summed message, squeezed to ``len(plates_parent + dims_parent)``
        dimensions.
    """
    # The shape of the full message
    shapes = [np.shape(array) for array in arrays]
    shape_full = utils.broadcasted_shape(*shapes)
    # Find axes that should be summed
    shape_parent = plates_parent + dims_parent
    sum_axes = utils.axes_to_collapse(shape_full, shape_parent)
    # Compute the multiplier for cancelling the plate multiplier. Because we
    # are summing over the dimensions already in this function (for
    # efficiency), we need to cancel the effect of the plate multiplier
    # applied in the message_to_parent function. Only plate axes count.
    # An axis index may be non-negative or negative, so normalize it to a
    # negative index: it is a plate axis iff it lies before the trailing
    # len(dims_parent) variable axes. (Comparing a non-negative index with
    # len(plates_parent) would miss extra leading plate axes of shape_full.)
    ndim_full = len(shape_full)
    ndim_dims = len(dims_parent)
    r = 1
    for j in sum_axes:
        j_neg = j - ndim_full if j >= 0 else j
        if j_neg < -ndim_dims:
            r *= shape_full[j]
    # Compute the sum-product
    m = utils.sum_multiply(*arrays,
                           axis=sum_axes,
                           sumaxis=True,
                           keepdims=True) / r
    # Remove extra axes
    m = utils.squeeze_to_dim(m, len(shape_parent))
    return m
Ejemplo n.º 10
0
    def _mask_to_parent(self, index):
        """
        Get the mask with respect to parent[index].

        The mask tells which plate connections are active. The mask is "summed"
        (logical or) and reshaped into the plate shape of the parent. Thus, it
        can't be used for masking messages, because some plates have been summed
        already. This method is used for propagating the mask to parents.

        Parameters
        ----------
        index : int
            Index of the parent in ``self.parents``.

        Returns
        -------
        ndarray
            The mask OR-reduced onto the parent's plates.

        Raises
        ------
        ValueError
            If the mask from ``_compute_mask_to_parent`` is not a sub-shape of
            this node's plates with respect to the parent.
        """
        mask = self._compute_mask_to_parent(index, self.mask)

        # Check the shape of the mask
        plates_to_parent = self._plates_to_parent(index)
        if not utils.is_shape_subset(np.shape(mask), plates_to_parent):
            raise ValueError("In node %s, the mask being sent to "
                             "parent[%d] (%s) has invalid shape: The shape of "
                             "the mask %s is not a sub-shape of the plates of "
                             "the node with respect to the parent %s. It could "
                             "be that this node (%s) is manipulating plates "
                             "but has not overridden the method "
                             "_compute_mask_to_parent."
                             % (self.name,
                                index,
                                self.parents[index].name,
                                np.shape(mask),
                                plates_to_parent,
                                self.__class__.__name__))

        # "Sum" (i.e., logical or) over the plates that have unit length in
        # the parent node.
        parent_plates = self.parents[index].plates
        s = utils.axes_to_collapse(np.shape(mask), parent_plates)
        mask = np.any(mask, axis=s, keepdims=True)
        mask = utils.squeeze_to_dim(mask, len(parent_plates))
        return mask
Ejemplo n.º 11
0
def _mask_sum(plates_parent, mask):
    """
    OR-reduce ``mask`` onto the plate shape ``plates_parent``.

    The mask is reduced with ``np.any`` over the axes that
    ``utils.axes_to_collapse`` reports as needing collapse to reach
    ``plates_parent``. This presumably means extra leading axes and axes
    where the parent has unit length. The result is then squeezed to
    ``len(plates_parent)`` dimensions.
    """
    collapsed = np.any(mask,
                       axis=utils.axes_to_collapse(np.shape(mask),
                                                   plates_parent),
                       keepdims=True)
    return utils.squeeze_to_dim(collapsed, len(plates_parent))
Ejemplo n.º 12
0
    def _compute_message_and_mask_to_parent(self, index, m, *u_parents):
        """
        Compute the message and the mask to the parent ``index``.

        To avoid building huge message arrays, this node sums over some axes
        already here, via ``_message_sum_multiply``, rather than leaving all
        of the summation to the generic message passing.

        NOTE(review): the original comment said the mask must be applied
        because of this early summation. However, the message values are
        never multiplied by ``self.mask`` here. Confirm that the incoming
        messages ``m`` are already masked.

        Parameters
        ----------
        index : int
            Index of the parent in ``self.parents``.
        m : list
            The two messages from the children; modified in place.
        *u_parents
            Moments of the parents; ``u_parents[k][i]`` is moment ``i`` of
            parent ``k``.

        Returns
        -------
        tuple
            ``(m, mask)``: the summed messages and ``self.mask`` OR-reduced
            onto the parent's plates.
        """
        mask = self.mask
        parent = self.parents[index]

        for i in (0, 1):
            # Append i+1 trailing unit axes to the message from the children.
            m[i] = utils.add_trailing_axes(m[i], i + 1)

            # Moment i of every parent except the recipient.
            others = [u[i] for (k, u) in enumerate(u_parents) if k != index]

            # Sum over some axes already here to avoid huge message matrices.
            m[i] = _message_sum_multiply(parent.plates,
                                         parent.dims[i],
                                         m[i],
                                         *others)

        # OR-reduce the mask onto the parent's plates.
        collapse = utils.axes_to_collapse(np.shape(mask), parent.plates)
        mask = utils.squeeze_to_dim(np.any(mask, axis=collapse, keepdims=True),
                                    len(parent.plates))

        return (m, mask)
Ejemplo n.º 13
0
    def _message_to_parent(self, index):
        """
        Compute the message to the parent ``index``.

        The message and mask come from ``_get_message_and_mask_to_parent``.
        Each non-``None`` moment is processed as follows:

        - It is multiplied by the mask, which is first summed over the plates
          that neither the message nor the parent has.
        - It is multiplied by the plate multiplier ``r``.
        - It is summed over the axes the parent does not have.
        - It is squeezed to the parent's shape.

        Parameters
        ----------
        index : int
            Index of the parent in ``self.parents``.

        Returns
        -------
        list
            The message container returned by
            ``_get_message_and_mask_to_parent``, with each entry replaced in
            place. ``None`` entries are left untouched.

        Raises
        ------
        ValueError
            If ``index`` is not smaller than the number of parents. Also if
            the plates of the message, the mask and the parent are not a
            broadcastable subset of this node's plates.
        """
        # Compute the message, check plates, apply mask and sum over some plates
        if index >= len(self.parents):
            raise ValueError("Parent index larger than the number of parents")

        # Compute the message and mask
        (m, mask) = self._get_message_and_mask_to_parent(index)
        mask = utils.squeeze(mask)

        # Plates in the mask
        plates_mask = np.shape(mask)

        # The parent we're sending the message to
        parent = self.parents[index]

        # Plates of this node with respect to the parent. This does not depend
        # on the moment, so compute it once outside the loop.
        plates_self = self._plates_to_parent(index)

        # Compact the message to a proper shape
        for i in range(len(m)):

            # Empty messages are given as None. We can ignore those.
            if m[i] is None:
                continue

            # Plates in the message (strip the parent's variable axes)
            shape_m = np.shape(m[i])
            dim_parent = len(parent.dims[i])
            if dim_parent > 0:
                plates_m = shape_m[:-dim_parent]
            else:
                plates_m = shape_m

            # Compute the multiplier (multiply by the number of plates for
            # which the message, the mask and the parent have single
            # plates). Such a plate is meant to be broadcasted but because
            # the parent has singular plate axis, it won't broadcast (and
            # sum over it), so we need to multiply it.
            try:
                r = self._plate_multiplier(plates_self,
                                           plates_m,
                                           plates_mask,
                                           parent.plates)
            except ValueError as err:
                # Chain the original error so the traceback shows the cause.
                raise ValueError("The plates of the message, the mask and "
                                 "parent[%d] node (%s) are not a "
                                 "broadcastable subset of the plates of "
                                 "this node (%s).  The message has shape "
                                 "%s, meaning plates %s. The mask has "
                                 "plates %s. This node has plates %s with "
                                 "respect to the parent[%d], which has "
                                 "plates %s."
                                 % (index,
                                    parent.name,
                                    self.name,
                                    np.shape(m[i]),
                                    plates_m,
                                    plates_mask,
                                    plates_self,
                                    index,
                                    parent.plates)) from err

            # Add variable axes to the mask
            shape_mask = np.shape(mask) + (1,) * len(parent.dims[i])
            mask_i = np.reshape(mask, shape_mask)

            # Sum over plates that are not in the message nor in the parent
            shape_parent = parent.get_shape(i)
            shape_msg = utils.broadcasted_shape(shape_m, shape_parent)
            axes_mask = utils.axes_to_collapse(shape_mask, shape_msg)
            mask_i = np.sum(mask_i, axis=axes_mask, keepdims=True)

            # Compute the masked message and sum over the plates that the
            # parent does not have.
            axes_msg = utils.axes_to_collapse(shape_msg, shape_parent)
            m[i] = utils.sum_multiply(mask_i, m[i], r,
                                      axis=axes_msg,
                                      keepdims=True)

            # Remove leading singular plates if the parent does not have
            # those plate axes.
            m[i] = utils.squeeze_to_dim(m[i], len(shape_parent))

        return m
Ejemplo n.º 14
0
    def _message_to_parent(self, index):
        """
        Compute the message to the parent ``index``.

        The message and mask come from ``_get_message_and_mask_to_parent``.
        Each non-``None`` moment is processed as follows:

        - It is multiplied by the mask, which is first summed over the plates
          that neither the message nor the parent has.
        - It is multiplied by the plate multiplier ``r``.
        - It is summed over the axes the parent does not have.
        - It is squeezed to the parent's shape.

        Parameters
        ----------
        index : int
            Index of the parent in ``self.parents``.

        Returns
        -------
        list
            The message container returned by
            ``_get_message_and_mask_to_parent``, with each entry replaced in
            place. ``None`` entries are left untouched.

        Raises
        ------
        ValueError
            If ``index`` is not smaller than the number of parents. Also if
            the plates of the message, the mask and the parent are not a
            broadcastable subset of this node's plates.
        """
        # Compute the message, check plates, apply mask and sum over some plates
        if index >= len(self.parents):
            raise ValueError("Parent index larger than the number of parents")

        # Compute the message and mask
        (m, mask) = self._get_message_and_mask_to_parent(index)
        mask = utils.squeeze(mask)

        # Plates in the mask
        plates_mask = np.shape(mask)

        # The parent we're sending the message to
        parent = self.parents[index]

        # Plates of this node with respect to the parent. This does not depend
        # on the moment, so compute it once outside the loop.
        plates_self = self._plates_to_parent(index)

        # Compact the message to a proper shape
        for i in range(len(m)):

            # Empty messages are given as None. We can ignore those.
            if m[i] is None:
                continue

            # Plates in the message (strip the parent's variable axes)
            shape_m = np.shape(m[i])
            dim_parent = len(parent.dims[i])
            if dim_parent > 0:
                plates_m = shape_m[:-dim_parent]
            else:
                plates_m = shape_m

            # Compute the multiplier (multiply by the number of plates for
            # which the message, the mask and the parent have single
            # plates). Such a plate is meant to be broadcasted but because
            # the parent has singular plate axis, it won't broadcast (and
            # sum over it), so we need to multiply it.
            try:
                r = self._plate_multiplier(plates_self,
                                           plates_m,
                                           plates_mask,
                                           parent.plates)
            except ValueError as err:
                # Chain the original error so the traceback shows the cause.
                raise ValueError("The plates of the message, the mask and "
                                 "parent[%d] node (%s) are not a "
                                 "broadcastable subset of the plates of "
                                 "this node (%s).  The message has shape "
                                 "%s, meaning plates %s. The mask has "
                                 "plates %s. This node has plates %s with "
                                 "respect to the parent[%d], which has "
                                 "plates %s."
                                 % (index,
                                    parent.name,
                                    self.name,
                                    np.shape(m[i]),
                                    plates_m,
                                    plates_mask,
                                    plates_self,
                                    index,
                                    parent.plates)) from err

            # Add variable axes to the mask
            shape_mask = np.shape(mask) + (1,) * len(parent.dims[i])
            mask_i = np.reshape(mask, shape_mask)

            # Sum over plates that are not in the message nor in the parent
            shape_parent = parent.get_shape(i)
            shape_msg = utils.broadcasted_shape(shape_m, shape_parent)
            axes_mask = utils.axes_to_collapse(shape_mask, shape_msg)
            mask_i = np.sum(mask_i, axis=axes_mask, keepdims=True)

            # Compute the masked message and sum over the plates that the
            # parent does not have.
            axes_msg = utils.axes_to_collapse(shape_msg, shape_parent)
            m[i] = utils.sum_multiply(mask_i, m[i], r,
                                      axis=axes_msg,
                                      keepdims=True)

            # Remove leading singular plates if the parent does not have
            # those plate axes.
            m[i] = utils.squeeze_to_dim(m[i], len(shape_parent))

        return m