Ejemplo n.º 1
0
def message_sum_multiply(plates_parent, dims_parent, *arrays):
    """
    Compute message to parent and sum over plates.

    Divide by the plate multiplier.

    The arrays are broadcast together and passed to ``misc.sum_multiply``,
    which sums over the axes returned by ``misc.axes_to_collapse`` for the
    parent shape ``plates_parent + dims_parent``.  The result is divided by
    the product of the sizes of the summed *plate* axes.  This cancels the
    plate multiplier that (per the comment below) ``message_to_parent``
    applies, because the summing is already done here.

    Parameters
    ----------
    plates_parent : sequence of int
        Plate shape of the parent.
    dims_parent : sequence of int
        Variable (dimension) shape of the parent; these are the trailing
        ``len(dims_parent)`` axes of the message.
    *arrays : array_like
        Factors of the message, broadcast against each other.

    Returns
    -------
    ndarray
        The summed, rescaled message, squeezed with ``misc.squeeze_to_dim``
        to ``len(plates_parent) + len(dims_parent)`` dimensions.
    """
    # Accept any sequence for the shapes (tuple + list would otherwise fail)
    plates_parent = tuple(plates_parent)
    dims_parent = tuple(dims_parent)

    # The shape of the full message
    shapes = [np.shape(array) for array in arrays]
    shape_full = misc.broadcasted_shape(*shapes)
    # Find axes that should be summed
    shape_parent = plates_parent + dims_parent
    sum_axes = misc.axes_to_collapse(shape_full, shape_parent)
    # Compute the multiplier for cancelling the
    # plate-multiplier.  Because we are summing over the
    # dimensions already in this function (for efficiency), we
    # need to cancel the effect of the plate-multiplier
    # applied in the message_to_parent function.
    ndim_full = len(shape_full)
    ndim_dims = len(dims_parent)
    r = 1
    for axis in sum_axes:
        # Work with negative indices: the variable axes are always the last
        # len(dims_parent) axes of the full message, so an axis is a plate
        # axis iff its negative index lies below -len(dims_parent).  Comparing
        # a non-negative index against len(plates_parent) would be wrong
        # whenever the message's ndim differs from the parent's.
        neg_axis = axis - ndim_full if axis >= 0 else axis
        if neg_axis < -ndim_dims:
            r *= shape_full[neg_axis]
    # Compute the sum-product
    m = misc.sum_multiply(*arrays, axis=sum_axes, sumaxis=True,
                          keepdims=True) / r
    # Remove extra axes
    m = misc.squeeze_to_dim(m, len(shape_parent))
    return m
Ejemplo n.º 2
0
def message_sum_multiply(plates_parent, dims_parent, *arrays):
    """
    Compute message to parent and sum over plates.

    Divide by the plate multiplier.

    The arrays are broadcast together and passed to ``misc.sum_multiply``,
    which sums over the axes returned by ``misc.axes_to_collapse`` for the
    parent shape ``plates_parent + dims_parent``.  The result is divided by
    the product of the sizes of the summed *plate* axes.  This cancels the
    plate multiplier that (per the comment below) ``message_to_parent``
    applies, because the summing is already done here.

    Parameters
    ----------
    plates_parent : sequence of int
        Plate shape of the parent.
    dims_parent : sequence of int
        Variable (dimension) shape of the parent; these are the trailing
        ``len(dims_parent)`` axes of the message.
    *arrays : array_like
        Factors of the message, broadcast against each other.

    Returns
    -------
    ndarray
        The summed, rescaled message, squeezed with ``misc.squeeze_to_dim``
        to ``len(plates_parent) + len(dims_parent)`` dimensions.
    """
    # Accept any sequence for the shapes (tuple + list would otherwise fail)
    plates_parent = tuple(plates_parent)
    dims_parent = tuple(dims_parent)

    # The shape of the full message
    shapes = [np.shape(array) for array in arrays]
    shape_full = misc.broadcasted_shape(*shapes)
    # Find axes that should be summed
    shape_parent = plates_parent + dims_parent
    sum_axes = misc.axes_to_collapse(shape_full, shape_parent)
    # Compute the multiplier for cancelling the
    # plate-multiplier.  Because we are summing over the
    # dimensions already in this function (for efficiency), we
    # need to cancel the effect of the plate-multiplier
    # applied in the message_to_parent function.
    ndim_full = len(shape_full)
    ndim_dims = len(dims_parent)
    r = 1
    for axis in sum_axes:
        # Work with negative indices: the variable axes are always the last
        # len(dims_parent) axes of the full message, so an axis is a plate
        # axis iff its negative index lies below -len(dims_parent).  Comparing
        # a non-negative index against len(plates_parent) would be wrong
        # whenever the message's ndim differs from the parent's.
        neg_axis = axis - ndim_full if axis >= 0 else axis
        if neg_axis < -ndim_dims:
            r *= shape_full[neg_axis]
    # Compute the sum-product
    m = misc.sum_multiply(*arrays,
                          axis=sum_axes,
                          sumaxis=True,
                          keepdims=True) / r
    # Remove extra axes
    m = misc.squeeze_to_dim(m, len(shape_parent))
    return m
Ejemplo n.º 3
0
    def setup(self):
        """
        Precompute cached quantities; call right before optimization.

        Sets ``self.N`` (size of the first plate axis of ``self.X``),
        ``self.XX`` (the masked ``misc.sum_multiply`` reduction of the second
        moment of ``self.X``) and ``self.Lambda`` (the first moment entry of
        ``self.X.parents[1]``).
        """
        node_x = self.X

        # Give the mask two trailing singleton axes
        mask = node_x.mask[..., np.newaxis, np.newaxis]

        # Number of plates: size of the first plate axis (alternative:
        # np.sum(mask))
        self.N = node_x.plates[0]

        # Sum <XX> over plates, weighted by the mask
        xx_moment = node_x.get_moments()[1]
        self.XX = misc.sum_multiply(xx_moment, mask,
                                    axis=(-1, -2),
                                    sumaxis=False,
                                    keepdims=False)

        # Moments of the parent
        parent_moments = node_x.parents[1].get_moments()
        self.Lambda = parent_moments[0]
Ejemplo n.º 4
0
            def m_function(*args):
                """
                Evaluate the message function ``m`` at ``args`` and reduce
                ``m[i]`` towards the plates of ``parent``.

                NOTE(review): ``m`` is both called (``m(*args)``) and indexed
                (``m[i]``) in this body, ``lpdf`` is used only for its shape,
                and nothing is returned.  This looks unfinished or truncated;
                confirm what ``m`` is and what this function should return.

                Relies on the enclosing scope for ``self``, ``m``, ``i``,
                ``plates_self``, ``plates_mask``, ``parent``,
                ``multiplier_parent``, ``mask_i`` and ``shape_parent``.
                """
                lpdf = m(*args)
                # Log pdf only contains plate axes!
                plates_m = np.shape(lpdf)
                # Combined multiplier: broadcasting multiplier of this node's,
                # the message's and the mask's plates w.r.t. parent.plates,
                # times the one for self.plates_multiplier vs.
                # multiplier_parent.
                r = (self.broadcasting_multiplier(plates_self,
                                                  plates_m,
                                                  plates_mask,
                                                  parent.plates) *
                     self.broadcasting_multiplier(self.plates_multiplier,
                                                  multiplier_parent))
                # Plate axes of the message to collapse to reach parent.plates
                axes_msg = misc.axes_to_collapse(plates_m, parent.plates)
                # Mask and rescale the message, summing over axes_msg
                # (keepdims=True)
                m[i] = misc.sum_multiply(mask_i, m[i], r,
                                         axis=axes_msg,
                                         keepdims=True)

                # Remove leading singular plates if the parent does not have
                # those plate axes.
                m[i] = misc.squeeze_to_dim(m[i], len(shape_parent))
Ejemplo n.º 5
0
            def m_function(*args):
                """
                Evaluate the message function ``m`` at ``args`` and reduce
                ``m[i]`` towards the plates of ``parent``.

                NOTE(review): ``m`` is both called (``m(*args)``) and indexed
                (``m[i]``) in this body, ``lpdf`` is used only for its shape,
                and nothing is returned.  This looks unfinished or truncated;
                confirm what ``m`` is and what this function should return.

                Relies on the enclosing scope for ``self``, ``m``, ``i``,
                ``plates_self``, ``plates_mask``, ``parent``,
                ``multiplier_parent``, ``mask_i`` and ``shape_parent``.
                """
                lpdf = m(*args)
                # Log pdf only contains plate axes!
                plates_m = np.shape(lpdf)
                # Combined multiplier: broadcasting multiplier of this node's,
                # the message's and the mask's plates w.r.t. parent.plates,
                # times the one for self.plates_multiplier vs.
                # multiplier_parent.
                r = (self.broadcasting_multiplier(plates_self, plates_m,
                                                  plates_mask, parent.plates) *
                     self.broadcasting_multiplier(self.plates_multiplier,
                                                  multiplier_parent))
                # Plate axes of the message to collapse to reach parent.plates
                axes_msg = misc.axes_to_collapse(plates_m, parent.plates)
                # Mask and rescale the message, summing over axes_msg
                # (keepdims=True)
                m[i] = misc.sum_multiply(mask_i,
                                         m[i],
                                         r,
                                         axis=axes_msg,
                                         keepdims=True)

                # Remove leading singular plates if the parent does not have
                # those plate axes.
                m[i] = misc.squeeze_to_dim(m[i], len(shape_parent))
Ejemplo n.º 6
0
    def _computations_for_A_and_X(self, XpXn, XpXp):
        """
        Compute expectations involving A and X from the moments of B and S.

        Reads the moments ``(B, BB)`` of ``self.B_node`` and ``(S, SS)`` of
        ``self.S_node`` and combines them with ``XpXn`` and ``XpXp`` through
        ``misc.sum_multiply``.  Judging by the comments below, ``XpXn`` is
        presumably the summed <x_{n-1} x_n^T> and ``XpXp`` the summed
        <x_{n-1} x_{n-1}^T>; confirm with the callers.

        Returns
        -------
        tuple
            ``(A_XpXn, A_XpXp_A, CovA_XpXp)``: the three quantities named in
            the "Compute:" comments in the body.
        """

        # Get moments of B and S
        (B, BB) = self.B_node.get_moments()
        # Cov[B] = <B B^T> - <B><B>^T: the product is an outer product over
        # the last two axes of B, with all earlier axes shared.
        CovB = BB - B[..., :, :, None, None] * B[..., None, None, :, :]

        u_S = self.S_node.get_moments()
        S = u_S[0]
        SS = u_S[1]

        #
        # Expectations with respect to A and X
        #

        # TODO/FIXME: If S and B have overlapping plates, then these will give
        # wrong results, because those plates of S are summed before multiplying
        # by the plates of B. There should be some "smart einsum" function which
        # would compute sum-multiplys intelligently given a number of inputs.

        # Compute: \sum_n <A_n> <x_{n-1} x_n^T>
        # Axes: (N, D, D, D, K)
        S_XpXn = misc.sum_multiply(
            S[..., None, None, :], XpXn[..., :, None, :, :, None], axis=(-3, -2, -1), sumaxis=False
        )
        # Note: S_XpXn[..., :, :, :] is a full slice, i.e. a view of the whole
        # S_XpXn array.
        A_XpXn = misc.sum_multiply(B[..., :, :, None, :], S_XpXn[..., :, :, :], axis=(-4, -2), sumaxis=False)

        # Compute: \sum_n <A_n> <x_{n-1} x_{n-1}^T> <A_n>^T
        # Axes: (N, D, D, D, K, D, K)
        SS_XpXp = misc.sum_multiply(
            SS[..., None, :, None, :], XpXp[..., None, :, None, :, None], axis=(-4, -3, -2, -1), sumaxis=False
        )
        # Note: this is the only call in this method with sumaxis=True.
        B_SS_XpXp = misc.sum_multiply(
            B[..., :, :, :, None, None], SS_XpXp[..., :, :, :, :], axis=(-4, -3), sumaxis=True
        )
        A_XpXp_A = misc.sum_multiply(B_SS_XpXp[..., :, None, :, :], B[..., None, :, :, :], axis=(-4, -3), sumaxis=False)

        # Compute: \sum_n tr(CovA_n <x_{n-1} x_{n-1}^T>)
        # Axes: (D,D,K,D,K)
        CovA_XpXp = misc.sum_multiply(CovB, SS_XpXp, axis=(-5,), sumaxis=False)

        return (A_XpXn, A_XpXp_A, CovA_XpXp)
Ejemplo n.º 7
0
        def sum_plates(V, plates):
            """
            Reduce ``V`` with ``misc.sum_multiply`` and rescale it.

            ``V`` is multiplied against an all-ones array shaped like ``Q``
            and reduced with ``axis=(-1, -2), sumaxis=False``.  The result is
            scaled by ``self.node_X.broadcasting_multiplier`` of ``plates``
            against the leading (all but the last two) axes of ``V``.
            """
            ones_like_q = np.ones(np.shape(Q))
            plates_v = np.shape(V)[:-2]
            multiplier = self.node_X.broadcasting_multiplier(plates, plates_v)

            summed = misc.sum_multiply(V, ones_like_q,
                                       axis=(-1, -2),
                                       sumaxis=False,
                                       keepdims=False)
            return multiplier * summed
Ejemplo n.º 8
0
    def _message_to_parent(self, index):
        """
        Compute the message to ``self.parents[index]``, masked and summed.

        For every non-None component ``m[i]`` of the message returned by
        ``self._get_message_and_mask_to_parent(index)``:

        * a plate multiplier ``r`` is computed with ``self._plate_multiplier``;
        * the (squeezed) mask gets trailing singleton axes for the parent's
          variable dimensions and is summed over the axes collapsed between
          the mask shape and the broadcast message/parent shape;
        * mask, message and ``r`` are combined with ``misc.sum_multiply``,
          summing over the axes collapsed between the message and the
          parent's shape (``keepdims=True``);
        * the result is squeezed to ``len(parent.get_shape(i))`` dimensions.

        Parameters
        ----------
        index : int
            Index of the parent in ``self.parents``.

        Returns
        -------
        The message container ``m``; its components are replaced in place.
        ``None`` components are left as they are.

        Raises
        ------
        ValueError
            If ``index >= len(self.parents)``, or if ``self._plate_multiplier``
            raises ValueError (re-raised with a descriptive message that
            chains the original error).
        """

        # Compute the message, check plates, apply mask and sum over some plates
        if index >= len(self.parents):
            raise ValueError("Parent index larger than the number of parents")

        # Compute the message and mask
        (m, mask) = self._get_message_and_mask_to_parent(index)
        mask = misc.squeeze(mask)

        # Plates in the mask
        plates_mask = np.shape(mask)

        # The parent we're sending the message to
        parent = self.parents[index]

        # Compact the message to a proper shape
        for i, msg in enumerate(m):

            # Empty messages are given as None. We can ignore those.
            if msg is None:
                continue

            # Plates in the message
            shape_m = np.shape(msg)
            dim_parent = len(parent.dims[i])
            plates_m = shape_m[:-dim_parent] if dim_parent > 0 else shape_m

            # Compute the multiplier (multiply by the number of plates for
            # which the message, the mask and the parent have single
            # plates).  Such a plate is meant to be broadcasted but because
            # the parent has singular plate axis, it won't broadcast (and
            # sum over it), so we need to multiply it.
            plates_self = self._plates_to_parent(index)
            try:
                r = self._plate_multiplier(plates_self,
                                           plates_m,
                                           plates_mask,
                                           parent.plates)
            except ValueError as err:
                raise ValueError("The plates of the message, the mask and "
                                 "parent[%d] node (%s) are not a "
                                 "broadcastable subset of the plates of "
                                 "this node (%s).  The message has shape "
                                 "%s, meaning plates %s. The mask has "
                                 "plates %s. This node has plates %s with "
                                 "respect to the parent[%d], which has "
                                 "plates %s."
                                 % (index,
                                    parent.name,
                                    self.name,
                                    shape_m,
                                    plates_m,
                                    plates_mask,
                                    plates_self,
                                    index,
                                    parent.plates)) from err

            # Add variable axes to the mask
            shape_mask = plates_mask + (1,) * dim_parent
            mask_i = np.reshape(mask, shape_mask)

            # Sum over plates that are not in the message nor in the parent
            shape_parent = parent.get_shape(i)
            shape_msg = misc.broadcasted_shape(shape_m, shape_parent)
            axes_mask = misc.axes_to_collapse(shape_mask, shape_msg)
            mask_i = np.sum(mask_i, axis=axes_mask, keepdims=True)

            # Compute the masked message and sum over the plates that the
            # parent does not have.
            axes_msg = misc.axes_to_collapse(shape_msg, shape_parent)
            m[i] = misc.sum_multiply(mask_i, msg, r,
                                     axis=axes_msg,
                                     keepdims=True)

            # Remove leading singular plates if the parent does not have
            # those plate axes.
            m[i] = misc.squeeze_to_dim(m[i], len(shape_parent))

        return m