Esempio n. 1
0
def message_sum_multiply(plates_parent, dims_parent, *arrays):
    """
    Multiply ``arrays`` together and sum the product to the parent's shape.

    The summation over plates is done here (for efficiency) instead of in the
    message_to_parent function, so the result is divided by the product of
    the summed plate lengths to cancel the plate multiplier applied there.

    Parameters
    ----------
    plates_parent : tuple
        Plates of the parent node.
    dims_parent : tuple
        Variable (non-plate) dimensions of the parent node.
    *arrays : array_like
        Factors of the message; their shapes are broadcast together.

    Returns
    -------
    ndarray
        The summed and rescaled message, reduced with
        ``utils.squeeze_to_dim`` to ``len(plates_parent + dims_parent)``
        dimensions.
    """
    full_shape = utils.broadcasted_shape(*[np.shape(a) for a in arrays])
    target_shape = plates_parent + dims_parent
    collapsed = utils.axes_to_collapse(full_shape, target_shape)

    # Only plate axes contribute to the correction: non-negative indices that
    # fall inside the parent plates, or negative indices lying to the left of
    # the parent's variable dimensions.
    n_plates = len(plates_parent)
    n_dims = len(dims_parent)
    divisor = 1
    for axis in collapsed:
        if not (0 <= axis < n_plates or axis < -n_dims):
            continue
        divisor *= full_shape[axis]

    summed = utils.sum_multiply(*arrays,
                                axis=collapsed,
                                sumaxis=True,
                                keepdims=True)
    return utils.squeeze_to_dim(summed / divisor, len(target_shape))
Esempio n. 2
0
def _message_sum_multiply(plates_parent, dims_parent, *arrays):
    """
    Compute the message to a parent and sum it over plates.

    Plates are summed here already, so the result is divided by the product
    of the lengths of the summed plate axes; this cancels the plate
    multiplier applied in the message_to_parent function.

    Parameters
    ----------
    plates_parent : tuple
        Plates of the parent node.
    dims_parent : tuple
        Variable dimensions of the parent node.
    *arrays : array_like
        Factors of the message; their shapes must broadcast together.

    Returns
    -------
    ndarray
        The summed message, squeezed by ``utils.squeeze_to_dim`` to
        ``len(plates_parent + dims_parent)`` dimensions.
    """
    full = utils.broadcasted_shape(*[np.shape(arr) for arr in arrays])
    parent_shape = plates_parent + dims_parent
    axes = utils.axes_to_collapse(full, parent_shape)

    n_plates, n_dims = len(plates_parent), len(dims_parent)
    # Summed axes that are plate axes of the parent layout
    plate_axes = [ax for ax in axes if 0 <= ax < n_plates or ax < -n_dims]
    scale = 1
    for ax in plate_axes:
        scale *= full[ax]

    product = utils.sum_multiply(*arrays, axis=axes, sumaxis=True,
                                 keepdims=True)
    return utils.squeeze_to_dim(product / scale, len(parent_shape))
Esempio n. 3
0
    def _plate_multiplier(plates, *args):
        """
        Return the product of the plate lengths that every arg lacks.

        The axes of ``plates`` are aligned from the right with each shape in
        ``args``. An axis contributes its length when every shape in ``args``
        either does not reach that far or has length 1 there.

        Raises
        ------
        ValueError
            If a shape in ``args`` is not a sub-shape of ``plates`` (or, via
            ``utils.broadcasted_shape``, does not broadcast with it).
        """
        # Validate broadcasting of each shape against the plates
        for shape in args:
            utils.broadcasted_shape(plates, shape)

        # Each shape must also be a subset of the plates
        for shape in args:
            if not utils.is_shape_subset(shape, plates):
                raise ValueError("The shapes in args are not a sub-shape of "
                                 "plates.")

        multiplier = 1
        for axis in range(-len(plates), 0):
            if all(-axis > len(shape) or shape[axis] == 1 for shape in args):
                multiplier *= plates[axis]
        return multiplier
Esempio n. 4
0
    def _plate_multiplier(plates, *args):
        """
        Compute the factor for plates that are unit or missing in all args.

        Shapes are compared right-aligned (NumPy broadcasting style). For each
        axis of ``plates``, if every shape in ``args`` is shorter than that
        axis' distance from the right or has length 1 there, the plate length
        is multiplied into the result.

        Raises
        ------
        ValueError
            If a shape in ``args`` is not a sub-shape of ``plates``.
        """
        # Broadcasting check first (result is not needed)
        for other in args:
            utils.broadcasted_shape(plates, other)

        # Then the sub-shape check
        for other in args:
            if not utils.is_shape_subset(other, plates):
                raise ValueError("The shapes in args are not a sub-shape of "
                                 "plates.")

        factor = 1
        # offset counts from the right: offset == 1 is the last axis.  Visit
        # axes left to right, like the index-based formulation.
        for offset in range(len(plates), 0, -1):
            if all(len(other) < offset or other[-offset] == 1
                   for other in args):
                factor *= plates[-offset]
        return factor
Esempio n. 5
0
    def _plate_multiplier(plates, *args):
        """
        Compute the plate multiplier for the given shapes.

        ``plates`` is compared with every shape in ``args`` using NumPy
        broadcasting alignment (from the right). The lengths of all plate axes
        that are missing or equal to 1 in every shape of ``args`` are
        multiplied together.

        This serves as a correction factor for messages to parents. When this
        node has non-unit plates that are unit plates in the parent, the
        message is summed over those plates. If the message itself has a unit
        axis there, it would first have to be broadcast to this node's plates
        and then summed; multiplying by this factor gives the same result
        without materializing the broadcast. The first argument is this node's
        full plate shape (with respect to the parent); the others are the
        shape of the message array and the parent's plates (with respect to
        this node).

        Raises
        ------
        ValueError
            If a shape in ``args`` is not a sub-shape of ``plates``.
        """
        # Check broadcasting of the shapes
        for arg in args:
            utils.broadcasted_shape(plates, arg)

        # Every shape must be a subset of the plates
        for arg in args:
            if not utils.is_shape_subset(arg, plates):
                raise ValueError("The shapes in args are not a sub-shape of "
                                 "plates.")

        collapsible = [
            j for j in range(-len(plates), 0)
            if all(-j > len(arg) or arg[j] == 1 for arg in args)
        ]
        r = 1
        for j in collapsible:
            r *= plates[j]
        return r
Esempio n. 6
0
    def _plate_multiplier(plates, *args):
        """
        Return the multiplier that replaces broadcasting-then-summing.

        Shapes are aligned from the right. A plate axis contributes its length
        to the product if no shape in ``args`` has a non-unit length at that
        axis (an axis a shape does not reach counts as unit).

        Typical use is the correction factor for a message to a parent: plates
        that are non-unit in this node but unit in the parent are summed, and
        if the message has a unit axis there, multiplying by this factor is
        equivalent to broadcasting the message to this node's plates and then
        summing. ``plates`` is this node's full plate shape (with respect to
        the parent); ``args`` are the message shape and the parent's plates
        (with respect to this node).

        Raises
        ------
        ValueError
            If a shape in ``args`` is not a sub-shape of ``plates``.
        """
        for shape in args:
            # Validate that the shapes broadcast together
            utils.broadcasted_shape(plates, shape)

        for shape in args:
            if not utils.is_shape_subset(shape, plates):
                raise ValueError("The shapes in args are not a sub-shape of "
                                 "plates.")

        factor = 1
        for j in range(-len(plates), 0):
            # An arg "covers" axis j if it reaches it with a non-unit length
            covered = False
            for shape in args:
                if -j <= len(shape) and shape[j] != 1:
                    covered = True
                    break
            if not covered:
                factor *= plates[j]
        return factor
Esempio n. 7
0
    def _message_from_children(self):
        """
        Sum the messages sent to this node by all of its children.

        Returns a list with one accumulated message per entry of
        ``self.dims``. Each accumulator starts as the scalar
        ``np.array(0.0)``; ``None`` entries of a child's message are skipped.
        """
        total = [np.array(0.0) for _ in range(len(self.dims))]
        for child, parent_index in self.children:
            child_msg = child._message_to_parent(parent_index)
            for k in range(len(self.dims)):
                part = child_msg[k]
                if part is None:
                    continue
                # Validate that the message broadcasts to this node's shape;
                # the resulting shape itself is not needed.
                utils.broadcasted_shape(self.get_shape(k), np.shape(part))
                try:
                    # In-place addition works when the accumulator already
                    # has the broadcast shape
                    total[k] += part
                except ValueError:
                    # Accumulator too small for in-place broadcasting: grow it
                    total[k] = total[k] + part
        return total
Esempio n. 8
0
    def _message_from_children(self):
        """
        Accumulate the messages from all children of this node.

        Each accumulator starts as ``np.zeros(shape)`` for the corresponding
        entry of ``self.dims``, and every non-None child message is added to
        it (growing the accumulator when the message has a larger broadcast
        shape).

        Returns
        -------
        list
            One accumulated message array per entry of ``self.dims``.
        """
        msg = [np.zeros(shape) for shape in self.dims]

        def accumulate(i, contribution):
            # Validate broadcasting against this node's shape (result unused)
            utils.broadcasted_shape(self.get_shape(i), np.shape(contribution))
            try:
                # Add in place when the shapes allow it
                msg[i] += contribution
            except ValueError:
                msg[i] = msg[i] + contribution

        for (child, index) in self.children:
            m = child._message_to_parent(index)
            for i in range(len(self.dims)):
                if m[i] is not None:
                    accumulate(i, m[i])

        return msg
Esempio n. 9
0
    def __init__(self, *parents, dims=None, plates=None, name=""):
        """
        Initialize the node, validate its plates and register with parents.

        Parameters
        ----------
        *parents
            Parent nodes. Each truthy parent is informed of this node via
            ``parent._add_child(self, index)``.
        dims : required
            Dimensionality of the distribution (stored as ``self.dims``).
        plates : tuple, optional
            Plates of this node. If None, the plates obtained from
            ``self._plates_from_parent`` for every parent are broadcast
            together.
        name : str, optional
            Name of the node.

        Raises
        ------
        ValueError
            If ``dims`` is None, if ``_plates_from_parent`` returns None, if
            the parents' plates do not broadcast, or if they are not subsets
            of the given ``plates``.
        """
        if dims is None:
            # ValueError subclasses Exception, so existing handlers still
            # catch it.
            raise ValueError("You need to specify the dimensionality of the "
                             "distribution for class %s"
                             % str(self.__class__))

        self.dims = dims
        self.name = name

        # Parents
        self.parents = parents

        # Check plates. This happens before the parents are informed so that
        # a validation failure does not leave a half-initialized node
        # registered with its parents via _add_child.
        parent_plates = [self._plates_from_parent(index)
                         for index in range(len(self.parents))]
        if any(p is None for p in parent_plates):
            raise ValueError("Method _plates_from_parent returned None")

        if plates is None:
            # By default, use the minimum number of plates determined
            # from the parent nodes
            try:
                self.plates = utils.broadcasted_shape(*parent_plates)
            except ValueError as err:
                raise ValueError("The plates of the parents do not "
                                 "broadcast.") from err
        else:
            # Use custom plates
            self.plates = plates
            # Check that the parent_plates are a subset of plates.
            for p in parent_plates:
                if not utils.is_shape_subset(p, plates):
                    raise ValueError("The plates of the parents are not "
                                     "subsets of the given plates.")

        # By default, ignore all plates
        self.mask = np.array(False)

        # Children
        self.children = list()

        # Inform parent nodes only now that this node is fully initialized
        for (index, parent) in enumerate(self.parents):
            if parent:
                parent._add_child(self, index)
Esempio n. 10
0
    def _compute_moments(self, u_Z):
        """
        Build the probability vectors for all time instances.

        ``u_Z[0]`` gets a new time axis (at -2), and the joint probability
        arrays ``u_Z[1]`` are summed over axis -2 into marginal probability
        vectors. Both are broadcast to common plates (all axes except the last
        two) and concatenated along the time axis.

        Returns
        -------
        list
            A single-element list containing the concatenated array.
        """
        initial, joint = u_Z[0], u_Z[1]

        # Insert the singleton time axis in front of the last axis
        first = initial[..., None, :]
        rest = np.sum(joint, axis=-2)

        # Broadcast both to the same plates; the trailing (1, 1) keeps the
        # time and state axes untouched.
        common_plates = utils.broadcasted_shape(np.shape(first)[:-2],
                                                np.shape(rest)[:-2])
        ones = np.ones(common_plates + (1, 1))

        return [np.concatenate((first * ones, rest * ones), axis=-2)]
Esempio n. 11
0
 def _total_plates(cls, plates, *parent_plates):
     """
     Return the plates of a node given explicit plates and parents' plates.

     Parameters
     ----------
     plates : tuple or None
         Explicit plates. If None, the parents' plates are broadcast together.
     *parent_plates : tuple
         Plates of each parent.

     Returns
     -------
     tuple
         ``plates`` when given, otherwise the broadcasted parent plates.

     Raises
     ------
     ValueError
         If ``plates`` is None and the parents' plates do not broadcast, or
         if some parent's plates are not a subset of ``plates``.
     """
     if plates is None:
         # By default, use the minimum number of plates determined
         # from the parent nodes
         try:
             return utils.broadcasted_shape(*parent_plates)
         except ValueError as err:
             # Chain so the underlying broadcasting error stays visible
             raise ValueError("The plates of the parents do not "
                              "broadcast.") from err

     # Check that the parent_plates are a subset of plates.
     for p in parent_plates:
         if not utils.is_shape_subset(p, plates):
             raise ValueError("The plates %s of the parents "
                              "are not broadcastable to the given "
                              "plates %s."
                              % (p,
                                 plates))
     return plates
    def _compute_moments(self, u_Z):
        """
        Return the marginal probability vectors for every time step.

        A time axis is added to the initial probabilities ``u_Z[0]``, and the
        joint probability arrays ``u_Z[1]`` are summed over axis -2 into
        marginal probability vectors. Both are broadcast to the same plates
        (leading axes, all but the last two) and stacked along axis -2.

        Returns
        -------
        list
            ``[P]`` where ``P`` is the concatenated array.
        """
        def plates_of(arr):
            # Everything except the time and state axes
            return np.shape(arr)[:-2]

        p0 = u_Z[0][..., None, :]
        p = np.sum(u_Z[1], axis=-2)

        target = utils.broadcasted_shape(plates_of(p0), plates_of(p)) + (1, 1)
        parts = tuple(part * np.ones(target) for part in (p0, p))

        return [np.concatenate(parts, axis=-2)]
Esempio n. 13
0
 def _total_plates(cls, plates, *parent_plates):
     """
     Determine the plates of a node from its own and its parents' plates.

     Parameters
     ----------
     plates : tuple or None
         Custom plates of the node, or None to derive them from the parents.
     *parent_plates : tuple
         Plates of each parent node.

     Returns
     -------
     tuple
         The given ``plates`` if not None, otherwise the broadcasted parent
         plates.

     Raises
     ------
     ValueError
         If the parents' plates do not broadcast (when ``plates`` is None) or
         some parent's plates are not a subset of ``plates``.
     """
     if plates is None:
         # By default, use the minimum number of plates determined
         # from the parent nodes
         try:
             return utils.broadcasted_shape(*parent_plates)
         except ValueError as err:
             # Keep the original broadcasting error as the explicit cause
             raise ValueError("The plates of the parents do not "
                              "broadcast.") from err

     # Check that the parent_plates are a subset of plates.
     for p in parent_plates:
         if not utils.is_shape_subset(p, plates):
             raise ValueError("The plates %s of the parents "
                              "are not broadcastable to the given "
                              "plates %s."
                              % (p,
                                 plates))
     return plates
Esempio n. 14
0
    def __init__(self, *parents, dims=None, plates=None, name=""):
        """
        Initialize the node, check its plates and connect it to its parents.

        Parameters
        ----------
        *parents
            Parent nodes. Each truthy parent is informed of this node via
            ``parent._add_child(self, index)``.
        dims : required
            Dimensionality of the distribution (stored as ``self.dims``).
        plates : tuple, optional
            Plates of this node. If None, the plates obtained from
            ``self._plates_from_parent`` for every parent are broadcast
            together.
        name : str, optional
            Name of the node.

        Raises
        ------
        ValueError
            If ``dims`` is None, if ``_plates_from_parent`` returns None, if
            the parents' plates do not broadcast, or if they are not subsets
            of the given ``plates``.
        """
        if dims is None:
            # ValueError subclasses Exception, so existing handlers still
            # catch it.
            raise ValueError("You need to specify the dimensionality of the "
                             "distribution for class %s" % str(self.__class__))

        self.dims = dims
        self.name = name

        # Parents
        self.parents = parents

        # Check plates before informing the parents so that a failed
        # construction does not leave a half-initialized node registered
        # with them via _add_child.
        parent_plates = [
            self._plates_from_parent(index)
            for index in range(len(self.parents))
        ]
        if any(p is None for p in parent_plates):
            raise ValueError("Method _plates_from_parent returned None")

        if plates is None:
            # By default, use the minimum number of plates determined
            # from the parent nodes
            try:
                self.plates = utils.broadcasted_shape(*parent_plates)
            except ValueError as err:
                raise ValueError("The plates of the parents do not "
                                 "broadcast.") from err
        else:
            # Use custom plates
            self.plates = plates
            # Check that the parent_plates are a subset of plates.
            for p in parent_plates:
                if not utils.is_shape_subset(p, plates):
                    raise ValueError("The plates of the parents are not "
                                     "subsets of the given plates.")

        # By default, ignore all plates
        self.mask = np.array(False)

        # Children
        self.children = list()

        # Inform parent nodes now that this node is fully initialized
        for (index, parent) in enumerate(self.parents):
            if parent:
                parent._add_child(self, index)
Esempio n. 15
0
File: dot.py Progetto: vlall/bayespy
    def _message_to_parent(self, index):
        """
        Compute the messages to a parent node.

        For each of the two moments (``ind`` = 0 and 1) the message to
        ``self.parents[index]`` is computed with a single ``np.einsum`` call
        whose operands are an array of ones with the parent's shape, this
        node's mask, the moments of the other parents and the messages from
        the children. Axes that have length 1 in the parent are summed inside
        ``einsum`` and re-inserted afterwards, and the result is scaled by
        ``self._plate_multiplier``. Only the messages are returned; the mask is
        applied here, not returned.

        Parameters
        ----------
        index : int
            Index of the target parent in ``self.parents``.

        Returns
        -------
        list
            ``[msg0, msg1]``, one message array per moment.

        Raises
        ------
        ValueError
            If ``index >= len(self.parents)``.
        """

        # Check index
        if index >= len(self.parents):
            raise ValueError("Parent index larger than the number of parents")

        # Get messages from other parents and children
        u_parents = self._message_from_parents(exclude=index)
        m = self._message_from_children()
        mask = self.mask

        # Normally we don't need to care about masks when computing the
        # message. However, in this node we want to avoid computing huge message
        # arrays so we sum some axes already here. Thus, we need to apply the
        # mask.

        parent = self.parents[index]

        #
        # Compute the messages, one per moment (ind = 0 and ind = 1)
        #

        msg = [None, None]
        
        # Compute the two messages
        for ind in range(2):

            # The total number of keys for the non-plate dimensions. Plate
            # axes get keys above N so they never clash with dimension keys.
            N = (ind+1) * self.N_keys

            # Add an array of ones to ensure proper shape and number of
            # plates. Note that this adds an axis for each plate. At the end, we
            # want to remove axes that were created only because of this
            # array.
            # NOTE(review): parent_num_dims is not used below.
            parent_num_dims = len(parent.dims[ind])
            parent_num_plates = len(parent.plates)
            # Plate keys run from N + num_plates down to N + 1, i.e. they are
            # aligned from the right: the last plate axis of every operand
            # gets key N + 1, mirroring NumPy broadcasting alignment.
            parent_plate_keys = list(range(N + parent_num_plates,
                                           N,
                                           -1))
            parent_dim_keys = self.in_keys[index]
            if ind == 1:
                # The second moment uses a shifted copy of the keys followed
                # by the original keys.
                parent_dim_keys = ([key + self.N_keys
                                    for key in self.in_keys[index]]
                                   + parent_dim_keys)
            # np.einsum is called in its "sublist" form:
            # operand, keys, operand, keys, ..., output keys
            args = []
            args.append(np.ones((1,)*parent_num_plates + parent.dims[ind]))
            args.append(parent_plate_keys + parent_dim_keys)

            # This variable counts the maximum number of plates of the
            # arguments, thus it will tell the number of plates in the result
            # (if the artificially added plates above were ignored).
            result_num_plates = 0
            result_plates = ()

            # Mask and its keys
            mask_num_plates = np.ndim(mask)
            mask_plates = np.shape(mask)
            mask_plate_keys = list(range(N + mask_num_plates, 
                                         N,
                                         -1))
            result_num_plates = max(result_num_plates,
                                    mask_num_plates)
            result_plates = utils.broadcasted_shape(result_plates,
                                                    mask_plates)
            args.append(mask)
            args.append(mask_plate_keys)

            # Moments and keys of other parents
            for (k, u) in enumerate(u_parents):
                if k != index:
                    num_dims = (ind+1) * len(self.in_keys[k])
                    num_plates = np.ndim(u[ind]) - num_dims
                    plates = np.shape(u[ind])[:num_plates]
                    plate_keys = list(range(N + num_plates, 
                                            N,
                                            -1))
                    dim_keys = self.in_keys[k]
                    if ind == 1:
                        dim_keys = ([key + self.N_keys 
                                     for key in self.in_keys[k]]
                                    + dim_keys)
                    args.append(u[ind])
                    args.append(plate_keys + dim_keys)

                    result_num_plates = max(result_num_plates, num_plates)
                    result_plates = utils.broadcasted_shape(result_plates,
                                                            plates)

            # Message and keys from children
            child_num_dims = (ind+1) * len(self.out_keys)
            child_num_plates = np.ndim(m[ind]) - child_num_dims
            child_plates = np.shape(m[ind])[:child_num_plates]
            child_plate_keys = list(range(N + child_num_plates,
                                          N,
                                          -1))
            child_dim_keys = self.out_keys
            if ind == 1:
                child_dim_keys = ([key + self.N_keys
                                   for key in self.out_keys]
                                  + child_dim_keys)
            args.append(m[ind])
            args.append(child_plate_keys + child_dim_keys)

            result_num_plates = max(result_num_plates, child_num_plates)
            result_plates = utils.broadcasted_shape(result_plates,
                                                    child_plates)

            # Output keys, that is, the keys of the parent[index]
            parent_keys = parent_plate_keys + parent_dim_keys

            # Performance trick: Check which axes can be summed because they
            # have length 1 or are non-existing in parent[index]. Thus, remove
            # keys corresponding to unit length axes in parent[index] so that
            # einsum sums over those axes. After computations, these axes must
            # be added back in order to get the correct shape for the message.

            parent_shape = parent.get_shape(ind)
            removed_axes = []
            # range() is evaluated once, so j indexes positions in the
            # original key list while the list itself shrinks.
            for j in range(len(parent_keys)):
                if parent_shape[j] == 1:
                    # Remove the key (take into account the number of keys that
                    # have already been removed)
                    del parent_keys[j-len(removed_axes)]
                    removed_axes.append(j)

            args.append(parent_keys)

            # THE BEEF: Compute the message
            msg[ind] = np.einsum(*args)

            # Find the correct shape for the message array
            message_shape = list(np.shape(msg[ind]))
            # First, add back the axes with length 1
            for ax in removed_axes:
                message_shape.insert(ax, 1)
            # Second, remove leading axes for plates that were not present in
            # the child nor other parents' messages. This is not really
            # necessary, but it is just elegant to remove the leading unit
            # length axes that we added artificially at the beginning just
            # because we wanted the key mapping to be simple.
            if parent_num_plates > result_num_plates:
                del message_shape[:(parent_num_plates-result_num_plates)]
            # Then, the actual reshaping
            msg[ind] = np.reshape(msg[ind], message_shape)

            # Apply plate multiplier: If this node has non-unit plates that are
            # unit plates in the parent, those plates are summed. However, if
            # the message has unit axis for that plate, it should be first
            # broadcasted to the plates of this node and then summed to the
            # plates of the parent. In order to avoid this broadcasting and
            # summing, it is more efficient to just multiply by the correct
            # factor.
            r = self._plate_multiplier(self.plates, 
                                       result_plates,
                                       parent.plates)
            if r != 1:
                msg[ind] *= r
        
        return msg
Esempio n. 16
0
    def _message_to_parent(self, index):
        """
        Compute the message to a parent and sum it to the parent's shape.

        The raw messages and mask come from
        ``self._get_message_and_mask_to_parent(index)``. For every non-None
        message array, the mask is applied, plate axes that the parent does
        not have are summed out, a plate multiplier compensates for plates
        that are unit in the message, the mask and the parent, and leading
        singular axes are removed with ``utils.squeeze_to_dim``.

        Parameters
        ----------
        index : int
            Index of the target parent in ``self.parents``.

        Returns
        -------
        list
            The message list (updated in place); ``None`` entries are kept.

        Raises
        ------
        ValueError
            If ``index`` is out of range, or if the plates of the message, the
            mask and the parent are not a broadcastable subset of this node's
            plates (the underlying error is chained as the cause).
        """
        # Compute the message, check plates, apply mask and sum over some plates
        if index >= len(self.parents):
            raise ValueError("Parent index larger than the number of parents")

        # Compute the message and mask
        (m, mask) = self._get_message_and_mask_to_parent(index)
        mask = utils.squeeze(mask)

        # Plates in the mask
        plates_mask = np.shape(mask)

        # The parent we're sending the message to
        parent = self.parents[index]

        # Compact the message to a proper shape
        for i in range(len(m)):

            # Empty messages are given as None. We can ignore those.
            if m[i] is None:
                continue

            # Plates in the message. shape_m[:-0] would be empty, hence the
            # explicit check for parents without variable dimensions.
            shape_m = np.shape(m[i])
            dim_parent = len(parent.dims[i])
            plates_m = shape_m[:-dim_parent] if dim_parent > 0 else shape_m

            # Compute the multiplier (multiply by the number of plates for
            # which the message, the mask and the parent have single
            # plates).  Such a plate is meant to be broadcasted but because
            # the parent has singular plate axis, it won't broadcast (and
            # sum over it), so we need to multiply it.
            plates_self = self._plates_to_parent(index)
            try:
                r = self._plate_multiplier(plates_self,
                                           plates_m,
                                           plates_mask,
                                           parent.plates)
            except ValueError as err:
                raise ValueError("The plates of the message, the mask and "
                                 "parent[%d] node (%s) are not a "
                                 "broadcastable subset of the plates of "
                                 "this node (%s).  The message has shape "
                                 "%s, meaning plates %s. The mask has "
                                 "plates %s. This node has plates %s with "
                                 "respect to the parent[%d], which has "
                                 "plates %s."
                                 % (index,
                                    parent.name,
                                    self.name,
                                    np.shape(m[i]),
                                    plates_m,
                                    plates_mask,
                                    plates_self,
                                    index,
                                    parent.plates)) from err

            # Add variable axes to the mask
            shape_mask = np.shape(mask) + (1,) * len(parent.dims[i])
            mask_i = np.reshape(mask, shape_mask)

            # Sum over plates that are not in the message nor in the parent
            shape_parent = parent.get_shape(i)
            shape_msg = utils.broadcasted_shape(shape_m, shape_parent)
            axes_mask = utils.axes_to_collapse(shape_mask, shape_msg)
            mask_i = np.sum(mask_i, axis=axes_mask, keepdims=True)

            # Compute the masked message and sum over the plates that the
            # parent does not have.
            axes_msg = utils.axes_to_collapse(shape_msg, shape_parent)
            m[i] = utils.sum_multiply(mask_i, m[i], r,
                                      axis=axes_msg,
                                      keepdims=True)

            # Remove leading singular plates if the parent does not have
            # those plate axes.
            m[i] = utils.squeeze_to_dim(m[i], len(shape_parent))

        return m
Esempio n. 17
0
    def _message_to_parent(self, index):
        """
        Compute the messages to a parent node.

        For each of the two moments (``ind`` = 0 and 1) the message to
        ``self.parents[index]`` is computed with a single ``np.einsum`` call
        whose operands are an array of ones with the parent's shape, this
        node's mask, the moments of the other parents and the messages from
        the children. Axes that have length 1 in the parent are summed inside
        ``einsum`` and re-inserted afterwards, and the result is scaled by
        ``self._plate_multiplier``. Only the messages are returned; the mask is
        applied here, not returned.

        Parameters
        ----------
        index : int
            Index of the target parent in ``self.parents``.

        Returns
        -------
        list
            ``[msg0, msg1]``, one message array per moment.

        Raises
        ------
        ValueError
            If ``index >= len(self.parents)``.
        """

        # Check index
        if index >= len(self.parents):
            raise ValueError("Parent index larger than the number of parents")

        # Get messages from other parents and children
        u_parents = self._message_from_parents(exclude=index)
        m = self._message_from_children()
        mask = self.mask

        # Normally we don't need to care about masks when computing the
        # message. However, in this node we want to avoid computing huge message
        # arrays so we sum some axes already here. Thus, we need to apply the
        # mask.

        parent = self.parents[index]

        #
        # Compute the messages, one per moment (ind = 0 and ind = 1)
        #

        msg = [None, None]

        # Compute the two messages
        for ind in range(2):

            # The total number of keys for the non-plate dimensions. Plate
            # axes get keys above N so they never clash with dimension keys.
            N = (ind + 1) * self.N_keys

            # Add an array of ones to ensure proper shape and number of
            # plates. Note that this adds an axis for each plate. At the end, we
            # want to remove axes that were created only because of this
            # array.
            # NOTE(review): parent_num_dims is not used below.
            parent_num_dims = len(parent.dims[ind])
            parent_num_plates = len(parent.plates)
            # Plate keys run from N + num_plates down to N + 1, i.e. they are
            # aligned from the right: the last plate axis of every operand
            # gets key N + 1, mirroring NumPy broadcasting alignment.
            parent_plate_keys = list(range(N + parent_num_plates, N, -1))
            parent_dim_keys = self.in_keys[index]
            if ind == 1:
                # The second moment uses a shifted copy of the keys followed
                # by the original keys.
                parent_dim_keys = (
                    [key + self.N_keys
                     for key in self.in_keys[index]] + parent_dim_keys)
            # np.einsum is called in its "sublist" form:
            # operand, keys, operand, keys, ..., output keys
            args = []
            args.append(np.ones((1, ) * parent_num_plates + parent.dims[ind]))
            args.append(parent_plate_keys + parent_dim_keys)

            # This variable counts the maximum number of plates of the
            # arguments, thus it will tell the number of plates in the result
            # (if the artificially added plates above were ignored).
            result_num_plates = 0
            result_plates = ()

            # Mask and its keys
            mask_num_plates = np.ndim(mask)
            mask_plates = np.shape(mask)
            mask_plate_keys = list(range(N + mask_num_plates, N, -1))
            result_num_plates = max(result_num_plates, mask_num_plates)
            result_plates = utils.broadcasted_shape(result_plates, mask_plates)
            args.append(mask)
            args.append(mask_plate_keys)

            # Moments and keys of other parents
            for (k, u) in enumerate(u_parents):
                if k != index:
                    num_dims = (ind + 1) * len(self.in_keys[k])
                    num_plates = np.ndim(u[ind]) - num_dims
                    plates = np.shape(u[ind])[:num_plates]
                    plate_keys = list(range(N + num_plates, N, -1))
                    dim_keys = self.in_keys[k]
                    if ind == 1:
                        dim_keys = (
                            [key + self.N_keys
                             for key in self.in_keys[k]] + dim_keys)
                    args.append(u[ind])
                    args.append(plate_keys + dim_keys)

                    result_num_plates = max(result_num_plates, num_plates)
                    result_plates = utils.broadcasted_shape(
                        result_plates, plates)

            # Message and keys from children
            child_num_dims = (ind + 1) * len(self.out_keys)
            child_num_plates = np.ndim(m[ind]) - child_num_dims
            child_plates = np.shape(m[ind])[:child_num_plates]
            child_plate_keys = list(range(N + child_num_plates, N, -1))
            child_dim_keys = self.out_keys
            if ind == 1:
                child_dim_keys = ([key + self.N_keys
                                   for key in self.out_keys] + child_dim_keys)
            args.append(m[ind])
            args.append(child_plate_keys + child_dim_keys)

            result_num_plates = max(result_num_plates, child_num_plates)
            result_plates = utils.broadcasted_shape(result_plates,
                                                    child_plates)

            # Output keys, that is, the keys of the parent[index]
            parent_keys = parent_plate_keys + parent_dim_keys

            # Performance trick: Check which axes can be summed because they
            # have length 1 or are non-existing in parent[index]. Thus, remove
            # keys corresponding to unit length axes in parent[index] so that
            # einsum sums over those axes. After computations, these axes must
            # be added back in order to get the correct shape for the message.

            parent_shape = parent.get_shape(ind)
            removed_axes = []
            # range() is evaluated once, so j indexes positions in the
            # original key list while the list itself shrinks.
            for j in range(len(parent_keys)):
                if parent_shape[j] == 1:
                    # Remove the key (take into account the number of keys that
                    # have already been removed)
                    del parent_keys[j - len(removed_axes)]
                    removed_axes.append(j)

            args.append(parent_keys)

            # THE BEEF: Compute the message
            msg[ind] = np.einsum(*args)

            # Find the correct shape for the message array
            message_shape = list(np.shape(msg[ind]))
            # First, add back the axes with length 1
            for ax in removed_axes:
                message_shape.insert(ax, 1)
            # Second, remove leading axes for plates that were not present in
            # the child nor other parents' messages. This is not really
            # necessary, but it is just elegant to remove the leading unit
            # length axes that we added artificially at the beginning just
            # because we wanted the key mapping to be simple.
            if parent_num_plates > result_num_plates:
                del message_shape[:(parent_num_plates - result_num_plates)]
            # Then, the actual reshaping
            msg[ind] = np.reshape(msg[ind], message_shape)

            # Apply plate multiplier: If this node has non-unit plates that are
            # unit plates in the parent, those plates are summed. However, if
            # the message has unit axis for that plate, it should be first
            # broadcasted to the plates of this node and then summed to the
            # plates of the parent. In order to avoid this broadcasting and
            # summing, it is more efficient to just multiply by the correct
            # factor.
            r = self._plate_multiplier(self.plates, result_plates,
                                       parent.plates)
            if r != 1:
                msg[ind] *= r

        return msg
    def _compute_phi_from_parents(u_mu, u_Lambda, u_B, u_S, u_v, u_N):
        """
        Compute the natural parameters of the chain from the parents' moments.

        Parameters
        ----------
        u_mu : list of arrays
           Moments of the initial-state mean; ``u_mu[0]`` has shape (..., D).
        u_Lambda : list of arrays
           Moments of the initial-state precision; ``u_Lambda[0]`` has shape
           (..., D, D).
        u_B : list of arrays
           Moments of B; ``u_B[0]`` has shape (..., D, D, K) and ``u_B[1]``
           has shape (..., D, D, K, D, K).
        u_S : list of arrays
           Moments of S; ``u_S[0]`` has shape (..., N-1, K) or (..., 1, K)
           and ``u_S[1]`` has shape (..., N-1, K, K).
        u_v : list of arrays
           Moments of the innovation noise precision; ``u_v[0]`` has shape
           (..., N-1, D) or (..., 1, D).
        u_N : sequence
           ``u_N[0]`` is the number of time instances N.

        Returns
        -------
        phi0 : array, shape plates + (N, D)
           Only the first time instance is non-zero (``Lambda * mu``).
        phi1 : array, shape plates + (N, D, D)
           Diagonal blocks, already multiplied by -0.5.
        phi2 : array, shape plates + (N-1, D, D)
           Super-diagonal blocks (not multiplied by 0.5, because super- and
           sub-diagonal contributions are summed together).

        Raises
        ------
        NotImplementedError
           If the innovation noise ``v`` has a time axis longer than one.
        """

        # Dimensionality of the Gaussian states
        D = np.shape(u_mu[0])[-1]

        # Number of time instances in the process
        N = u_N[0]

        # Helpful variables (shapes in comments)
        mu = u_mu[0]          # (..., D)
        Lambda = u_Lambda[0]  # (..., D, D)
        B = u_B[0]            # (..., D, D, K)
        BB = u_B[1]           # (..., D, D, K, D, K)
        S = u_S[0]            # (..., N-1, K) or (..., 1, K)
        SS = u_S[1]           # (..., N-1, K, K)
        v = u_v[0]            # (..., N-1, D) or (..., 1, D)

        # TODO/FIXME: Take into account plates!
        plates_phi0 = utils.broadcasted_shape(np.shape(mu)[:-1],
                                              np.shape(Lambda)[:-2])
        plates_phi1 = utils.broadcasted_shape(np.shape(Lambda)[:-2],
                                              np.shape(v)[:-2],
                                              np.shape(BB)[:-5],
                                              np.shape(SS)[:-3])
        plates_phi2 = utils.broadcasted_shape(np.shape(B)[:-3],
                                              np.shape(S)[:-2],
                                              np.shape(v)[:-2])
        phi0 = np.zeros(plates_phi0 + (N, D))
        phi1 = np.zeros(plates_phi1 + (N, D, D))
        phi2 = np.zeros(plates_phi2 + (N-1, D, D))

        # Parameters for x0
        phi0[..., 0, :] = np.einsum('...ik,...k->...i', Lambda, mu)
        phi1[..., 0, :, :] = Lambda

        # Diagonal blocks: -0.5 * (V_i + A_{i+1}' * V_{i+1} * A_{i+1})
        phi1[..., 1:, :, :] = v[..., np.newaxis] * np.identity(D)
        if np.ndim(v) >= 2 and np.shape(v)[-2] > 1:
            # A time-dependent v would make the intermediate below carry an
            # extra (N-1) axis of size D*K*D*K per step, so it is refused.
            raise NotImplementedError("This implementation is not efficient if "
                                      "innovation noise is time-dependent.")

        # v has at most a unit time axis here, so contract BB with v over d
        # first: the intermediate v_BB then has no (N-1) time axis.
        v_BB = np.einsum('...dikjl,...d->...ikjl',
                         BB[..., None, :, :, :, :, :],
                         v)
        phi1[..., :-1, :, :] += np.einsum('...ikjl,...kl->...ij',
                                          v_BB,
                                          SS)
        phi1 *= -0.5

        # Super-diagonal blocks: 0.5 * A.T * V
        # However, don't multiply by 0.5 because there are both super- and
        # sub-diagonal blocks (sum them together)
        phi2[..., :, :, :] = np.einsum('...jik,...k,...j->...ij',
                                       B[..., None, :, :, :],
                                       S,
                                       v)

        return (phi0, phi1, phi2)
Esempio n. 19
0
    def _message_to_parent(self, index):
        """
        Compute the message to ``self.parents[index]`` reduced to its shape.

        The raw message and mask are obtained from
        ``self._get_message_and_mask_to_parent(index)``. Each non-None message
        element is:

        1. masked;
        2. multiplied by the plate multiplier;
        3. summed over the axes that the parent does not have;
        4. stripped of leading singular axes.

        Parameters
        ----------
        index : int
            Index of the target parent in ``self.parents``.

        Returns
        -------
        m : list
            The processed message elements. ``None`` elements (empty
            messages) are left untouched.

        Raises
        ------
        ValueError
            If ``index`` is not smaller than the number of parents, or if the
            plates of the message, the mask and the parent are not a
            broadcastable subset of this node's plates.
        """
        if index >= len(self.parents):
            raise ValueError("Parent index larger than the number of parents")

        # Compute the message and mask
        (m, mask) = self._get_message_and_mask_to_parent(index)
        mask = utils.squeeze(mask)

        # Plates in the mask
        plates_mask = np.shape(mask)

        # The parent we're sending the message to
        parent = self.parents[index]

        # This node's plates with respect to the parent do not depend on the
        # message element, so compute them once outside the loop.
        plates_self = self._plates_to_parent(index)

        for i, m_i in enumerate(m):

            # Empty messages are given as None. We can ignore those.
            if m_i is None:
                continue

            # Split the message shape into plates and the parent's
            # variable dimensions (guard: shape[:-0] would be empty).
            shape_m = np.shape(m_i)
            dim_parent = len(parent.dims[i])
            plates_m = shape_m[:-dim_parent] if dim_parent > 0 else shape_m

            # Multiply by the number of plates for which the message, the mask
            # and the parent all have single plates: such a plate is meant to
            # be broadcasted, but because the parent has a singular plate axis
            # it won't be broadcast (and summed over), so multiply instead.
            try:
                r = self._plate_multiplier(plates_self,
                                           plates_m,
                                           plates_mask,
                                           parent.plates)
            except ValueError as err:
                raise ValueError("The plates of the message, the mask and "
                                 "parent[%d] node (%s) are not a "
                                 "broadcastable subset of the plates of "
                                 "this node (%s).  The message has shape "
                                 "%s, meaning plates %s. The mask has "
                                 "plates %s. This node has plates %s with "
                                 "respect to the parent[%d], which has "
                                 "plates %s."
                                 % (index,
                                    parent.name,
                                    self.name,
                                    shape_m,
                                    plates_m,
                                    plates_mask,
                                    plates_self,
                                    index,
                                    parent.plates)) from err

            # Add variable axes to the mask
            shape_mask = plates_mask + (1,) * dim_parent
            mask_i = np.reshape(mask, shape_mask)

            # Sum over mask plates that are in neither the message nor the
            # parent.
            shape_parent = parent.get_shape(i)
            shape_msg = utils.broadcasted_shape(shape_m, shape_parent)
            axes_mask = utils.axes_to_collapse(shape_mask, shape_msg)
            mask_i = np.sum(mask_i, axis=axes_mask, keepdims=True)

            # Compute the masked message and sum over the plates that the
            # parent does not have.
            axes_msg = utils.axes_to_collapse(shape_msg, shape_parent)
            msg_i = utils.sum_multiply(mask_i, m_i, r,
                                       axis=axes_msg,
                                       keepdims=True)

            # Remove leading singular plates if the parent does not have
            # those plate axes.
            m[i] = utils.squeeze_to_dim(msg_i, len(shape_parent))

        return m