Beispiel #1
0
    def _mask_to_parent(self, index):
        """
        Return this node's mask expressed in the plates of parent[index].

        The mask marks which plate connections are active. It is reduced with
        a logical OR over the plates that have unit length in the parent and
        then squeezed to the parent's number of plate axes. Because some plates
        have already been collapsed, the result cannot be used for masking
        messages; it is meant for propagating the mask on to the parents.

        Parameters
        ----------
        index : int
            Index of the parent.

        Returns
        -------
        ndarray
            Boolean mask in the plate shape of parent[index].

        Raises
        ------
        ValueError
            If the mask returned by ``_compute_mask_to_parent`` is not a
            sub-shape of this node's plates with respect to the parent.
        """
        local_mask = self._compute_mask_to_parent(index, self.mask)
        mask_shape = np.shape(local_mask)
        plates_wrt_parent = self._plates_to_parent(index)

        # A node that manipulates plates is expected to provide its own
        # _compute_mask_to_parent; otherwise the shapes may not match here.
        if not misc.is_shape_subset(mask_shape, plates_wrt_parent):
            fmt_args = (self.name,
                        index,
                        self.parents[index].name,
                        mask_shape,
                        plates_wrt_parent,
                        self.__class__.__name__)
            raise ValueError("In node %s, the mask being sent to "
                             "parent[%d] (%s) has invalid shape: The shape of "
                             "the mask %s is not a sub-shape of the plates of "
                             "the node with respect to the parent %s. It could "
                             "be that this node (%s) is manipulating plates "
                             "but has not overwritten the method "
                             "_compute_mask_to_parent."
                             % fmt_args)

        # Logical OR over the axes that the parent does not have (or has with
        # unit length), keeping dimensions so that squeezing works by count.
        parent_plates = self.parents[index].plates
        collapse_axes = misc.axes_to_collapse(mask_shape, parent_plates)
        reduced = np.any(local_mask, axis=collapse_axes, keepdims=True)
        return misc.squeeze_to_dim(reduced, len(parent_plates))
Beispiel #2
0
    def _mask_to_parent(self, index):
        """
        Compute the mask of this node as seen by parent[index].

        The active plate connections (the mask) are combined with a logical
        OR over the plates of unit length in the parent, and the result is
        reshaped into the parent's plate shape. As plates have been collapsed,
        this mask is only for propagation to parents, not for masking
        messages.

        Parameters
        ----------
        index : int
            Index of the parent.

        Returns
        -------
        ndarray
            Boolean mask with the plate shape of parent[index].

        Raises
        ------
        ValueError
            If the computed mask shape is not a sub-shape of this node's
            plates with respect to parent[index].
        """
        computed = self._compute_mask_to_parent(index, self.mask)
        computed_shape = np.shape(computed)
        expected_plates = self._plates_to_parent(index)

        shape_is_valid = misc.is_shape_subset(computed_shape, expected_plates)
        if not shape_is_valid:
            message = (
                "In node %s, the mask being sent to "
                "parent[%d] (%s) has invalid shape: The shape of "
                "the mask %s is not a sub-shape of the plates of "
                "the node with respect to the parent %s. It could "
                "be that this node (%s) is manipulating plates "
                "but has not overwritten the method "
                "_compute_mask_to_parent." %
                (self.name, index, self.parents[index].name, computed_shape,
                 expected_plates, self.__class__.__name__))
            raise ValueError(message)

        target_plates = self.parents[index].plates
        ored = np.any(computed,
                      axis=misc.axes_to_collapse(computed_shape,
                                                 target_plates),
                      keepdims=True)
        # Drop leading singleton axes the parent does not have.
        return misc.squeeze_to_dim(ored, len(target_plates))
Beispiel #3
0
def message_sum_multiply(plates_parent, dims_parent, *arrays):
    """
    Compute message to parent and sum over plates.

    Divide by the plate multiplier.

    The arrays are multiplied together (with broadcasting) and summed over
    the axes that the parent shape ``plates_parent + dims_parent`` does not
    have. Since the summation over plate axes is done here already (for
    efficiency), the result is divided by the sizes of the summed plate axes
    to cancel the plate multiplier applied in ``message_to_parent``.

    Parameters
    ----------
    plates_parent : tuple
        Plates of the parent.
    dims_parent : tuple
        Variable dimensions of the parent.
    *arrays : array_like
        Factors of the message.

    Returns
    -------
    ndarray
        Message with ``len(plates_parent) + len(dims_parent)`` axes (leading
        extra axes squeezed).
    """
    shape_full = misc.broadcasted_shape(*[np.shape(a) for a in arrays])
    shape_parent = plates_parent + dims_parent
    sum_axes = misc.axes_to_collapse(shape_full, shape_parent)

    n_plates = len(plates_parent)
    n_dims = len(dims_parent)

    def _is_plate_axis(axis):
        # A non-negative axis counts if it is among the first n_plates
        # positions; a negative axis counts if it lies before the trailing
        # n_dims positions.
        if axis >= 0:
            return axis < n_plates
        return axis < -n_dims

    divisor = 1
    for axis in filter(_is_plate_axis, sum_axes):
        divisor *= shape_full[axis]

    total = misc.sum_multiply(*arrays,
                              axis=sum_axes,
                              sumaxis=True,
                              keepdims=True)
    return misc.squeeze_to_dim(total / divisor, len(shape_parent))
Beispiel #4
0
def message_sum_multiply(plates_parent, dims_parent, *arrays):
    """
    Compute message to parent and sum over plates.

    Divide by the plate multiplier.

    The product of ``arrays`` is summed over every axis missing from (or of
    unit length in) the parent shape ``plates_parent + dims_parent``. The
    summed plate axes are divided out again because ``message_to_parent``
    applies its own plate multiplier.

    Parameters
    ----------
    plates_parent : tuple
        Plates of the parent node.
    dims_parent : tuple
        Variable dimensions of the parent node.
    *arrays : array_like
        Factors whose broadcasted product forms the message.

    Returns
    -------
    ndarray
        Summed message squeezed to the parent's number of axes.
    """
    shapes = tuple(map(np.shape, arrays))
    shape_full = misc.broadcasted_shape(*shapes)
    shape_parent = plates_parent + dims_parent
    sum_axes = misc.axes_to_collapse(shape_full, shape_parent)

    # Sizes of the summed axes that are plate axes (not variable dims).
    n_plates = len(plates_parent)
    n_dims = len(dims_parent)
    plate_sizes = [shape_full[ax] for ax in sum_axes
                   if (0 <= ax < n_plates) or (ax < -n_dims)]

    multiplier = 1
    for size in plate_sizes:
        multiplier = multiplier * size

    summed = misc.sum_multiply(*arrays, axis=sum_axes, sumaxis=True,
                               keepdims=True)
    summed = summed / multiplier
    return misc.squeeze_to_dim(summed, len(shape_parent))
Beispiel #5
0
        def _compute_weights_to_parent(self, index, weights):
            """
            Collapse the weights of this tiled node onto parent[index].

            Each plate axis of ``weights`` is split into a (tile, plate) pair
            of axes, and the tile axes are summed over. ``tiles`` is a free
            variable taken from the enclosing scope (not shown here).

            Parameters
            ----------
            index : int
                Index of the parent.
            weights : array_like
                Weights with respect to the plates of this node. Axes of unit
                length are treated as broadcast axes.

            Returns
            -------
            ndarray
                Weights in the plate shape of parent[index], with extra
                leading axes squeezed.
            """
            # Make plates, tiles and the weight shape equal length
            plates = self._plates_to_parent(index)
            shape_m = np.shape(weights)
            (plates, tiles_m, shape_m) = misc.make_equal_length(plates, tiles,
                                                                shape_m)

            # Axes with unit length in the weights are broadcast (even if the
            # plate itself is not unit length): nothing to split or sum there.
            # NOTE(review): unlike _compute_message_to_parent, no multiplier
            # for such broadcast tile axes is applied here -- confirm this is
            # intended for weights.
            plates = [1 if n == 1 else p for (p, n) in zip(plates, shape_m)]
            tiles_m = [1 if n == 1 else t for (t, n) in zip(tiles_m, shape_m)]

            # Interleave as (tile_0, plate_0, tile_1, plate_1, ...) so that
            # every other axis corresponds to tiles and every other to the
            # parent's plates. The () initializer keeps this working when
            # there are no axes at all: reduce() raises TypeError on an empty
            # sequence without one.
            shape = functools.reduce(lambda x, y: x + y,
                                     zip(tiles_m, plates), ())
            weights = np.reshape(weights, shape)

            # Sum over the tile axes (every other axis, starting at 0)
            axes = tuple(range(0, len(shape), 2))
            weights = np.sum(weights, axis=axes)

            # Remove extra leading axes
            ndim_parent = len(self.parents[index].plates)
            return misc.squeeze_to_dim(weights, ndim_parent)
Beispiel #6
0
        def _compute_message_to_parent(self, index, m, u_X):
            """
            Sum the messages of this tiled node over the tiles for parent[index].

            For each message, every axis is split into a (tile, plate/dim)
            pair of axes and the tile axes are summed over. ``tiles`` is a free
            variable taken from the enclosing scope (not shown here).

            Parameters
            ----------
            index : int
                Index of the parent.
            m : sequence
                Messages, one per entry of ``self.dims``.
            u_X : object
                Not used by this method.

            Returns
            -------
            list
                Messages in the shape of parent[index] (extra leading axes
                squeezed).
            """
            m = list(m)
            # Loop-invariant: plates of this node w.r.t. the parent.
            plates_to_parent = self._plates_to_parent(index)
            for (ind, msg) in enumerate(m):

                # Idea: Reshape the message array such that every other axis
                # will be summed and every other kept.
                shape_ind = plates_to_parent + self.dims[ind]
                # Variable dimensions are never tiled
                tiles_ind = tiles + (1,) * len(self.dims[ind])

                # Make shape tuples equal length
                shape_m = np.shape(msg)
                (tiles_ind, shape, shape_m) = misc.make_equal_length(tiles_ind,
                                                                     shape_ind,
                                                                     shape_m)

                # Axes of unit length in the message are broadcast over the
                # tiles. They are not split; instead the tile count goes into
                # the multiplier r, since summing would count them only once.
                r = 1
                shape = list(shape)
                tiles_ind = list(tiles_ind)
                for (j, n) in enumerate(shape_m):
                    if n == 1:
                        r *= tiles_ind[j]
                        shape[j] = 1
                        tiles_ind[j] = 1

                # Interleave as (tile_0, shape_0, tile_1, shape_1, ...). The
                # () initializer keeps this working when there are no axes:
                # reduce() raises TypeError on an empty sequence without one.
                shape = functools.reduce(lambda x, y: x + y,
                                         zip(tiles_ind, shape), ())
                msg = np.reshape(msg, shape)

                # Sum over the tile axes (every other axis, starting at 0)
                axes = tuple(range(0, len(shape), 2))
                msg = r * np.sum(msg, axis=axes)

                # Remove extra leading axes
                ndim_parent = len(self.parents[index].get_shape(ind))
                m[ind] = misc.squeeze_to_dim(msg, ndim_parent)

            return m
Beispiel #7
0
        def _compute_message_to_parent(self, index, m, u_X):
            """
            Sum the messages of this tiled node over the tiles for parent[index].

            For each message, every axis is split into a (tile, plate/dim)
            pair of axes and the tile axes are summed over. ``tiles`` is a free
            variable taken from the enclosing scope (not shown here).

            Parameters
            ----------
            index : int
                Index of the parent.
            m : sequence
                Messages, one per entry of ``self.dims``.
            u_X : object
                Not used by this method.

            Returns
            -------
            list
                Messages in the shape of parent[index] (extra leading axes
                squeezed).
            """
            m = list(m)
            # Loop-invariant: plates of this node w.r.t. the parent.
            plates_to_parent = self._plates_to_parent(index)
            for (ind, msg) in enumerate(m):

                # Idea: Reshape the message array such that every other axis
                # will be summed and every other kept.
                shape_ind = plates_to_parent + self.dims[ind]
                # Variable dimensions are never tiled
                tiles_ind = tiles + (1, ) * len(self.dims[ind])

                # Make shape tuples equal length
                shape_m = np.shape(msg)
                (tiles_ind, shape,
                 shape_m) = misc.make_equal_length(tiles_ind, shape_ind,
                                                   shape_m)

                # Axes of unit length in the message are broadcast over the
                # tiles. They are not split; instead the tile count goes into
                # the broadcasting multiplier r.
                r = 1
                shape = list(shape)
                tiles_ind = list(tiles_ind)
                for (j, n) in enumerate(shape_m):
                    if n == 1:
                        r *= tiles_ind[j]
                        shape[j] = 1
                        tiles_ind[j] = 1

                # Interleave as (tile_0, shape_0, tile_1, shape_1, ...). The
                # () initializer keeps this working when there are no axes:
                # reduce() raises TypeError on an empty sequence without one.
                shape = functools.reduce(lambda x, y: x + y,
                                         zip(tiles_ind, shape), ())
                msg = np.reshape(msg, shape)

                # Sum over the tile axes (every other axis, starting at 0)
                axes = tuple(range(0, len(shape), 2))
                msg = r * np.sum(msg, axis=axes)

                # Remove extra leading axes
                ndim_parent = len(self.parents[index].get_shape(ind))
                m[ind] = misc.squeeze_to_dim(msg, ndim_parent)

            return m
Beispiel #8
0
            def m_function(*args):
                """
                Evaluate the callable ``m`` and mask/sum the result.

                ``self``, ``m``, ``i``, ``mask_i``, ``plates_self``,
                ``plates_mask``, ``parent``, ``multiplier_parent`` and
                ``shape_parent`` are free variables taken from the enclosing
                scope, which is not shown here.

                NOTE(review): as written this looks broken. ``m`` is first
                called (``m(*args)``) and later indexed and assigned
                (``m[i] = ...``). Unless ``m`` supports both calls and item
                assignment, one of the two will fail. ``lpdf`` is only used
                for its shape, and nothing is returned, so the caller always
                gets None. Presumably the masking and summation were meant to
                be applied to ``lpdf`` and the result returned -- confirm
                against the caller before fixing.
                """
                lpdf = m(*args)
                # Log pdf only contains plate axes!
                plates_m = np.shape(lpdf)
                # Product of two multipliers: one from the plates of this node,
                # the message, the mask and the parent, and one from the plate
                # multipliers of this node and of the parent.
                r = (self.broadcasting_multiplier(plates_self,
                                                  plates_m,
                                                  plates_mask,
                                                  parent.plates) *
                     self.broadcasting_multiplier(self.plates_multiplier,
                                                  multiplier_parent))
                # Axes of the log pdf that the parent's plates do not have
                axes_msg = misc.axes_to_collapse(plates_m, parent.plates)
                m[i] = misc.sum_multiply(mask_i, m[i], r,
                                         axis=axes_msg,
                                         keepdims=True)

                # Remove leading singular plates if the parent does not have
                # those plate axes.
                m[i] = misc.squeeze_to_dim(m[i], len(shape_parent))
Beispiel #9
0
            def m_function(*args):
                """
                Evaluate the callable ``m`` and mask/sum the result.

                ``self``, ``m``, ``i``, ``mask_i``, ``plates_self``,
                ``plates_mask``, ``parent``, ``multiplier_parent`` and
                ``shape_parent`` are free variables taken from the enclosing
                scope, which is not shown here.

                NOTE(review): as written this looks broken. ``m`` is first
                called (``m(*args)``) and later indexed and assigned
                (``m[i] = ...``). Unless ``m`` supports both calls and item
                assignment, one of the two will fail. ``lpdf`` is only used
                for its shape, and nothing is returned, so the caller always
                gets None. Presumably the masking and summation were meant to
                be applied to ``lpdf`` and the result returned -- confirm
                against the caller before fixing.
                """
                lpdf = m(*args)
                # Log pdf only contains plate axes!
                plates_m = np.shape(lpdf)
                # Product of two multipliers: one from the plates of this node,
                # the message, the mask and the parent, and one from the plate
                # multipliers of this node and of the parent.
                r = (self.broadcasting_multiplier(plates_self, plates_m,
                                                  plates_mask, parent.plates) *
                     self.broadcasting_multiplier(self.plates_multiplier,
                                                  multiplier_parent))
                # Axes of the log pdf that the parent's plates do not have
                axes_msg = misc.axes_to_collapse(plates_m, parent.plates)
                m[i] = misc.sum_multiply(mask_i,
                                         m[i],
                                         r,
                                         axis=axes_msg,
                                         keepdims=True)

                # Remove leading singular plates if the parent does not have
                # those plate axes.
                m[i] = misc.squeeze_to_dim(m[i], len(shape_parent))
Beispiel #10
0
        def _compute_mask_to_parent(self, index, mask):
            """
            Collapse the mask of this tiled node onto the plates of parent[index].

            Each plate axis of ``mask`` is split into a (tile, plate) pair of
            axes, and a logical OR is taken over the tile axes. ``tiles`` is a
            free variable taken from the enclosing scope (not shown here).

            Parameters
            ----------
            index : int
                Index of the parent.
            mask : array_like
                Mask with respect to the plates of this node. Axes of unit
                length are treated as broadcast axes.

            Returns
            -------
            ndarray
                Boolean mask in the plate shape of parent[index], with extra
                leading axes squeezed.
            """
            # Make plates, tiles and the mask shape equal length
            plates = self._plates_to_parent(index)
            shape_m = np.shape(mask)
            (plates, tiles_m, shape_m) = misc.make_equal_length(plates,
                                                                tiles,
                                                                shape_m)

            # Axes with unit length in the mask are broadcast (even if the
            # plate itself is not unit length): nothing to split there. No
            # multiplier is needed because the reduction is a logical OR.
            plates = [1 if n == 1 else p for (p, n) in zip(plates, shape_m)]
            tiles_m = [1 if n == 1 else t for (t, n) in zip(tiles_m, shape_m)]

            # Interleave as (tile_0, plate_0, tile_1, plate_1, ...) so that
            # every other axis corresponds to tiles and every other to the
            # parent's plates. The () initializer keeps this working when
            # there are no axes at all: reduce() raises TypeError on an empty
            # sequence without one.
            shape = functools.reduce(lambda x, y: x + y,
                                     zip(tiles_m, plates), ())
            mask = np.reshape(mask, shape)

            # Logical OR over the tile axes (every other axis, starting at 0)
            axes = tuple(range(0, len(shape), 2))
            mask = np.any(mask, axis=axes)

            # Remove extra leading axes
            ndim_parent = len(self.parents[index].plates)
            return misc.squeeze_to_dim(mask, ndim_parent)
Beispiel #11
0
    def _message_to_parent(self, index):
        """
        Compute the messages to parent[index] in the shape of that parent.

        The raw messages and the mask come from
        ``_get_message_and_mask_to_parent``. For each message that is not
        None:

        - a plate multiplier is computed for broadcast plates that the
          parent does not sum over;
        - the mask (with trailing variable axes added) is applied;
        - the plates the parent does not have are summed over.

        Parameters
        ----------
        index : int
            Index of the parent.

        Returns
        -------
        list
            The processed messages; entries that were None stay None.

        Raises
        ------
        ValueError
            If ``index`` is not smaller than the number of parents, or if
            the plates of a message, the mask and the parent are not a
            broadcastable subset of this node's plates.
        """
        if index >= len(self.parents):
            raise ValueError("Parent index larger than the number of parents")

        # Compute the message and mask
        (m, mask) = self._get_message_and_mask_to_parent(index)
        mask = misc.squeeze(mask)

        # Plates in the mask
        plates_mask = np.shape(mask)

        # The parent we're sending the message to
        parent = self.parents[index]

        # Plates of this node with respect to the parent (loop-invariant)
        plates_self = self._plates_to_parent(index)

        # Compact the message to a proper shape
        for (i, m_i) in enumerate(m):

            # Empty messages are given as None. We can ignore those.
            if m_i is None:
                continue

            # Plates in the message: strip the parent's variable dimensions
            shape_m = np.shape(m_i)
            dim_parent = len(parent.dims[i])
            plates_m = shape_m[:-dim_parent] if dim_parent > 0 else shape_m

            # Multiply by the number of plates for which the message, the
            # mask and the parent have single plates. Such a plate is meant
            # to be broadcast, but because the parent has a singular plate
            # axis it won't broadcast (and sum over it), so we multiply.
            try:
                r = self._plate_multiplier(plates_self,
                                           plates_m,
                                           plates_mask,
                                           parent.plates)
            except ValueError as err:
                # Chain the original error so its details are not lost.
                raise ValueError("The plates of the message, the mask and "
                                 "parent[%d] node (%s) are not a "
                                 "broadcastable subset of the plates of "
                                 "this node (%s).  The message has shape "
                                 "%s, meaning plates %s. The mask has "
                                 "plates %s. This node has plates %s with "
                                 "respect to the parent[%d], which has "
                                 "plates %s."
                                 % (index,
                                    parent.name,
                                    self.name,
                                    shape_m,
                                    plates_m,
                                    plates_mask,
                                    plates_self,
                                    index,
                                    parent.plates)) from err

            # Add variable axes to the mask
            shape_mask = plates_mask + (1,) * dim_parent
            mask_i = np.reshape(mask, shape_mask)

            # Sum over plates that are not in the message nor in the parent
            shape_parent = parent.get_shape(i)
            shape_msg = misc.broadcasted_shape(shape_m, shape_parent)
            axes_mask = misc.axes_to_collapse(shape_mask, shape_msg)
            mask_i = np.sum(mask_i, axis=axes_mask, keepdims=True)

            # Compute the masked message and sum over the plates that the
            # parent does not have.
            axes_msg = misc.axes_to_collapse(shape_msg, shape_parent)
            m_i = misc.sum_multiply(mask_i, m_i, r,
                                    axis=axes_msg,
                                    keepdims=True)

            # Remove leading singular plates if the parent does not have
            # those plate axes.
            m[i] = misc.squeeze_to_dim(m_i, len(shape_parent))

        return m