Exemple #1
0
    def _compute_moments(self, *u_parents):
        """
        Concatenate the moments of the parents.

        For every entry of ``self.dims``, the corresponding moment arrays of
        all parents are joined with ``np.concatenate`` along the axis derived
        from ``self._axis``.

        Parameters
        ----------
        *u_parents
            One indexable collection of moment arrays per parent;
            ``u_parent[k]`` is paired with ``self.dims[k]``.

        Returns
        -------
        list of ndarray
            One concatenated array per entry of ``self.dims``.
        """
        # np.concatenate does not broadcast its inputs, although moment
        # messages may rely on broadcasting.  As a workaround, every array is
        # broadcast explicitly to a common shape (except along the
        # concatenated axis) before concatenating.
        moments = []
        for (k, dims_k) in enumerate(self.dims):
            # Translate the plate axis into an axis of the moment array
            axis = self._axis - len(dims_k)
            parts = [u_parent[k] for u_parent in u_parents]
            # Shapes of the parts, with the concatenated axis (when the array
            # has it) replaced by a unit axis
            shapes = []
            for part in parts:
                shape = list(np.shape(part))
                if len(shape) >= abs(axis):
                    shape[axis] = 1
                shapes.append(shape)
            common_shape = misc.broadcasted_shape(*shapes)
            # The concatenated axis must get its length explicitly
            trailing_ones = (1,) * (abs(axis) - 1)
            target_shapes = [
                misc.broadcasted_shape(common_shape,
                                       (length,) + trailing_ones)
                for length in self._lengths
            ]
            # Multiplying by an array of ones forces the broadcasting
            expanded = [part * np.ones(target)
                        for (part, target) in zip(parts, target_shapes)]
            moments.append(np.concatenate(expanded, axis=axis))

        return moments
Exemple #2
0
    def _compute_moments(self, *u_parents):
        """
        Concatenate the moments of the parents.

        For every entry of ``self.dims``, the corresponding moment arrays of
        all parents are joined with ``np.concatenate`` along the axis derived
        from ``self._axis``.

        Parameters
        ----------
        *u_parents
            One indexable collection of moment arrays per parent;
            ``u_parent[k]`` is paired with ``self.dims[k]``.

        Returns
        -------
        list of ndarray
            One concatenated array per entry of ``self.dims``.
        """
        # np.concatenate does not broadcast its inputs, although moment
        # messages may rely on broadcasting.  As a workaround, every array is
        # broadcast explicitly to a common shape (except along the
        # concatenated axis) before concatenating.
        moments = []
        for (k, dims_k) in enumerate(self.dims):
            # Translate the plate axis into an axis of the moment array
            axis = self._axis - len(dims_k)
            parts = [u_parent[k] for u_parent in u_parents]
            # Shapes of the parts, with the concatenated axis (when the array
            # has it) replaced by a unit axis
            shapes = []
            for part in parts:
                shape = list(np.shape(part))
                if len(shape) >= abs(axis):
                    shape[axis] = 1
                shapes.append(shape)
            common_shape = misc.broadcasted_shape(*shapes)
            # The concatenated axis must get its length explicitly
            trailing_ones = (1,) * (abs(axis) - 1)
            target_shapes = [
                misc.broadcasted_shape(common_shape,
                                       (length,) + trailing_ones)
                for length in self._lengths
            ]
            # Multiplying by an array of ones forces the broadcasting
            expanded = [part * np.ones(target)
                        for (part, target) in zip(parts, target_shapes)]
            moments.append(np.concatenate(expanded, axis=axis))

        return moments
Exemple #3
0
def message_sum_multiply(plates_parent, dims_parent, *arrays):
    """
    Multiply the arrays and sum the product down to the parent's shape.

    The parent's shape is ``plates_parent + dims_parent``.  The result is
    divided by the product of the lengths of the summed plate axes.
    """
    # Shape of the full message before any summing
    shape_full = misc.broadcasted_shape(*[np.shape(array)
                                          for array in arrays])
    # Axes of the full message that have to be summed out
    shape_parent = plates_parent + dims_parent
    sum_axes = misc.axes_to_collapse(shape_full, shape_parent)

    # The summing over plates is done already here (for efficiency), so the
    # plate multiplier applied in the message_to_parent function has to be
    # cancelled: collect the lengths of the summed plate axes.
    num_plates = len(plates_parent)
    num_dims = len(dims_parent)
    divisor = 1
    for axis in sum_axes:
        # A plate axis is either a non-negative index inside the plates or a
        # negative index that lies before the variable dimensions
        counted_from_start = 0 <= axis < num_plates
        counted_from_end = axis < 0 and axis < -num_dims
        if counted_from_start or counted_from_end:
            divisor *= shape_full[axis]

    # Sum-product over the collapsed axes, then drop the extra axes
    product = misc.sum_multiply(*arrays,
                                axis=sum_axes,
                                sumaxis=True,
                                keepdims=True)
    return misc.squeeze_to_dim(product / divisor, len(shape_parent))
Exemple #4
0
def message_sum_multiply(plates_parent, dims_parent, *arrays):
    """
    Multiply the arrays and sum the product down to the parent's shape.

    The parent's shape is ``plates_parent + dims_parent``.  The result is
    divided by the product of the lengths of the summed plate axes.
    """
    # Shape of the full message before any summing
    shape_full = misc.broadcasted_shape(*[np.shape(array)
                                          for array in arrays])
    # Axes of the full message that have to be summed out
    shape_parent = plates_parent + dims_parent
    sum_axes = misc.axes_to_collapse(shape_full, shape_parent)

    # The summing over plates is done already here (for efficiency), so the
    # plate multiplier applied in the message_to_parent function has to be
    # cancelled: collect the lengths of the summed plate axes.
    num_plates = len(plates_parent)
    num_dims = len(dims_parent)
    divisor = 1
    for axis in sum_axes:
        # A plate axis is either a non-negative index inside the plates or a
        # negative index that lies before the variable dimensions
        counted_from_start = 0 <= axis < num_plates
        counted_from_end = axis < 0 and axis < -num_dims
        if counted_from_start or counted_from_end:
            divisor *= shape_full[axis]

    # Sum-product over the collapsed axes, then drop the extra axes
    product = misc.sum_multiply(*arrays,
                                axis=sum_axes,
                                sumaxis=True,
                                keepdims=True)
    return misc.squeeze_to_dim(product / divisor, len(shape_parent))
Exemple #5
0
    def _message_from_children(self, u_self=None):
        msg = [np.zeros(shape) for shape in self.dims]
        #msg = [np.array(0.0) for i in range(len(self.dims))]
        isfunction = None
        for (child, index) in self.children:
            m = child._message_to_parent(index, u_parent=u_self)
            if callable(m):
                if isfunction is False:
                    raise NotImplementedError()
                elif isfunction is None:
                    msg = m
                else:

                    def join(m1, m2):
                        return (m1[0] + m2[0], m1[1] + m2[1])

                    msg = lambda x: join(m(x), msg(x))
                    isfunction = True
            else:
                if isfunction is True:
                    raise NotImplementedError()
                else:
                    isfunction = False
                    for i in range(len(self.dims)):
                        if m[i] is not None:
                            # Check broadcasting shapes
                            sh = misc.broadcasted_shape(
                                self.get_shape(i), np.shape(m[i]))
                            try:
                                # Try exploiting broadcasting rules
                                msg[i] += m[i]
                            except ValueError:
                                msg[i] = msg[i] + m[i]

        return msg
Exemple #6
0
    def _message_from_children(self, u_self=None):
        msg = [np.zeros(shape) for shape in self.dims]
        #msg = [np.array(0.0) for i in range(len(self.dims))]
        isfunction = None
        for (child,index) in self.children:
            m = child._message_to_parent(index, u_parent=u_self)
            if callable(m):
                if isfunction is False:
                    raise NotImplementedError()
                elif isfunction is None:
                    msg = m
                else:
                    def join(m1, m2):
                        return (m1[0] + m2[0], m1[1] + m2[1])
                    msg = lambda x: join(m(x), msg(x))
                    isfunction = True
            else:
                if isfunction is True:
                    raise NotImplementedError()
                else:
                    isfunction = False
                    for i in range(len(self.dims)):
                        if m[i] is not None:
                            # Check broadcasting shapes
                            sh = misc.broadcasted_shape(self.get_shape(i), np.shape(m[i]))
                            try:
                                # Try exploiting broadcasting rules
                                msg[i] += m[i]
                            except ValueError:
                                msg[i] = msg[i] + m[i]

        return msg
Exemple #7
0
def _subplots(plotfunc, *args, fig=None, kwargs=None):
    """Draw a grid of subplots, one per plate element.

    The same plotting function is used for every subplot.

    Each positional input is a pair:

    (x, 3), (y, 2), ...

    where x,y,... are the input arrays and 3,2,... are the ndim
    parameters.  The last ndim axes of an array form a single element that
    is handed to the plotting function; the remaining leading axes are the
    plates, which are broadcast against each other to define the grid.

    The subplots are added to `fig` (the current figure if `fig` is None)
    and `kwargs` is passed on to `plotfunc` as keyword arguments.

    High-level plotting functions should wrap low-level plotting functions
    with this function in order to generate subplots for plates.
    """

    kwargs = {} if kwargs is None else kwargs
    fig = plt.gcf() if fig is None else fig

    # Plates (leading axes) of each input array
    plates = []
    for (x, n) in args:
        full_shape = np.shape(x)
        plates.append(full_shape[:-n] if n > 0 else full_shape)

    # Full grid shape of the subplots
    broadcasted_plates = misc.broadcasted_shape(*plates)
    num_plates = len(broadcasted_plates)

    # Subplot layout: every other plate axis, counted from the end, goes to
    # the columns (N) and the remaining ones to the rows (M)
    M = np.prod(broadcasted_plates[-2::-2])
    N = np.prod(broadcasted_plates[-1::-2])
    strides_subplot = []
    for j in range(num_plates):
        stride = np.prod(broadcasted_plates[(j+2)::2])
        if (num_plates - j) % 2 == 0:
            # Axes at an even distance from the end are scaled by the
            # number of columns
            stride = stride * N
        strides_subplot.append(stride)

    for ind in misc.nested_iterator(broadcasted_plates):

        # Pick the element of each input that belongs to this subplot
        broadcasted_args = [
            x[misc.safe_indices(ind, plate)]
            for ((x, _), plate) in zip(args, plates)
        ]

        # Flat (zero-based) position of the subplot in the M x N grid
        ind_subplot = np.einsum('i,i', ind, strides_subplot)
        axes = fig.add_subplot(M, N, ind_subplot+1)
        plotfunc(*broadcasted_args, axes=axes, **kwargs)
Exemple #8
0
def _subplots(plotfunc, *args, fig=None, kwargs=None):
    """Draw a grid of subplots, one per plate element.

    The same plotting function is used for every subplot.

    Each positional input is a pair:

    (x, 3), (y, 2), ...

    where x,y,... are the input arrays and 3,2,... are the ndim
    parameters.  The last ndim axes of an array form a single element that
    is handed to the plotting function; the remaining leading axes are the
    plates, which are broadcast against each other to define the grid.

    The subplots are added to `fig` (the current figure if `fig` is None)
    and `kwargs` is passed on to `plotfunc` as keyword arguments.

    High-level plotting functions should wrap low-level plotting functions
    with this function in order to generate subplots for plates.
    """

    kwargs = {} if kwargs is None else kwargs
    fig = plt.gcf() if fig is None else fig

    # Plates (leading axes) of each input array
    plates = []
    for (x, n) in args:
        full_shape = np.shape(x)
        plates.append(full_shape[:-n] if n > 0 else full_shape)

    # Full grid shape of the subplots
    broadcasted_plates = misc.broadcasted_shape(*plates)
    num_plates = len(broadcasted_plates)

    # Subplot layout: every other plate axis, counted from the end, goes to
    # the columns (N) and the remaining ones to the rows (M)
    M = np.prod(broadcasted_plates[-2::-2])
    N = np.prod(broadcasted_plates[-1::-2])
    strides_subplot = []
    for j in range(num_plates):
        stride = np.prod(broadcasted_plates[(j + 2)::2])
        if (num_plates - j) % 2 == 0:
            # Axes at an even distance from the end are scaled by the
            # number of columns
            stride = stride * N
        strides_subplot.append(stride)

    for ind in misc.nested_iterator(broadcasted_plates):

        # Pick the element of each input that belongs to this subplot
        broadcasted_args = [
            x[misc.safe_indices(ind, plate)]
            for ((x, _), plate) in zip(args, plates)
        ]

        # Flat (zero-based) position of the subplot in the M x N grid
        ind_subplot = np.einsum('i,i', ind, strides_subplot)
        axes = fig.add_subplot(M, N, ind_subplot + 1)
        plotfunc(*broadcasted_args, axes=axes, **kwargs)
Exemple #9
0
    def _plate_multiplier(plates, *args):
        """
        Compute the plate multiplier for given shapes.

        The first shape is compared to all other shapes (using NumPy
        broadcasting rules). All the elements which are non-unit in the first
        shape but 1 in all other shapes are multiplied together.

        This method is used, for instance, for computing a correction factor for
        messages to parents: If this node has non-unit plates that are unit
        plates in the parent, those plates are summed. However, if the message
        has unit axis for that plate, it should be first broadcasted to the
        plates of this node and then summed to the plates of the parent. In
        order to avoid this broadcasting and summing, it is more efficient to
        just multiply by the correct factor. This method computes that
        factor. The first argument is the full plate shape of this node (with
        respect to the parent). The other arguments are the shape of the message
        array and the plates of the parent (with respect to this node).
        """
        
        # Check broadcasting of the shapes
        for arg in args:
            misc.broadcasted_shape(plates, arg)

        # Check that each arg-plates are a subset of plates?
        for arg in args:
            if not misc.is_shape_subset(arg, plates):
                raise ValueError("The shapes in args are not a sub-shape of "
                                 "plates.")
            
        r = 1
        for j in range(-len(plates),0):
            mult = True
            for arg in args:
                # if -j <= len(arg) and arg[j] != 1:
                if not (-j > len(arg) or arg[j] == 1):
                    mult = False
            if mult:
                r *= plates[j]
        return r
Exemple #10
0
 def _total_plates(cls, plates, *parent_plates):
     if plates is None:
         # By default, use the minimum number of plates determined
         # from the parent nodes
         try:
             return misc.broadcasted_shape(*parent_plates)
         except ValueError:
             raise ValueError("The plates of the parents do not broadcast.")
     else:
         # Check that the parent_plates are a subset of plates.
         for (ind, p) in enumerate(parent_plates):
             if not misc.is_shape_subset(p, plates):
                 raise ValueError("The plates %s of the parents "
                                  "are not broadcastable to the given "
                                  "plates %s." % (p, plates))
         return plates
Exemple #11
0
    def _message_from_children(self):
        msg = [np.zeros(shape) for shape in self.dims]
        #msg = [np.array(0.0) for i in range(len(self.dims))]
        for (child,index) in self.children:
            m = child._message_to_parent(index)
            for i in range(len(self.dims)):
                if m[i] is not None:
                    # Check broadcasting shapes
                    sh = misc.broadcasted_shape(self.get_shape(i), np.shape(m[i]))
                    try:
                        # Try exploiting broadcasting rules
                        msg[i] += m[i]
                    except ValueError:
                        msg[i] = msg[i] + m[i]

        return msg
Exemple #12
0
    def _message_from_children(self):
        msg = [np.zeros(shape) for shape in self.dims]
        #msg = [np.array(0.0) for i in range(len(self.dims))]
        for (child,index) in self.children:
            m = child._message_to_parent(index)
            for i in range(len(self.dims)):
                if m[i] is not None:
                    # Check broadcasting shapes
                    sh = misc.broadcasted_shape(self.get_shape(i), np.shape(m[i]))
                    try:
                        # Try exploiting broadcasting rules
                        msg[i] += m[i]
                    except ValueError:
                        msg[i] = msg[i] + m[i]

        return msg
Exemple #13
0
 def _total_plates(cls, plates, *parent_plates):
     if plates is None:
         # By default, use the minimum number of plates determined
         # from the parent nodes
         try:
             return misc.broadcasted_shape(*parent_plates)
         except ValueError:
             raise ValueError("The plates of the parents do not broadcast.")
     else:
         # Check that the parent_plates are a subset of plates.
         for (ind, p) in enumerate(parent_plates):
             if not misc.is_shape_subset(p, plates):
                 raise ValueError("The plates %s of the parents "
                                  "are not broadcastable to the given "
                                  "plates %s."
                                  % (p,
                                     plates))
         return plates
    def _compute_moments(self, u_Z):
        """
        Compute the moments given the moments of the parent.

        ``u_Z[0]`` gets a new (time) axis as its second-to-last axis and
        ``u_Z[1]`` is summed over its second-to-last axis.  The two arrays
        are then broadcast over their leading axes and concatenated along
        the second-to-last axis.

        Returns a list containing the single concatenated array.
        """
        # Initial state probabilities with a time axis added
        initial = u_Z[0][..., np.newaxis, :]
        # Joint probability arrays summed to marginal probability vectors
        marginals = np.sum(u_Z[1], axis=-2)

        # np.concatenate does not broadcast, so expand both arrays to the
        # same leading shape; the last two axes are left untouched
        plates = misc.broadcasted_shape(np.shape(initial)[:-2],
                                        np.shape(marginals)[:-2])
        expander = np.ones(plates + (1, 1))

        joined = np.concatenate((initial * expander, marginals * expander),
                                axis=-2)

        return [joined]
Exemple #15
0
    def _compute_moments(self, u_Z):
        """
        Compute the moments given the moments of the parent.

        ``u_Z[0]`` gets a new (time) axis as its second-to-last axis and
        ``u_Z[1]`` is summed over its second-to-last axis.  The two arrays
        are then broadcast over their leading axes and concatenated along
        the second-to-last axis.

        Returns a list containing the single concatenated array.
        """
        # Initial state probabilities with a time axis added
        initial = u_Z[0][..., np.newaxis, :]
        # Joint probability arrays summed to marginal probability vectors
        marginals = np.sum(u_Z[1], axis=-2)

        # np.concatenate does not broadcast, so expand both arrays to the
        # same leading shape; the last two axes are left untouched
        plates = misc.broadcasted_shape(np.shape(initial)[:-2],
                                        np.shape(marginals)[:-2])
        expander = np.ones(plates + (1, 1))

        joined = np.concatenate((initial * expander, marginals * expander),
                                axis=-2)

        return [joined]
Exemple #16
0
    def _message_to_parent(self, index, u_parent=None):
        """
        Compute the message to a parent node.

        For each of the first two moments, the message from the children is
        contracted with the moments of the other parents in one ``np.einsum``
        call.  The call uses the interleaved form of ``np.einsum``: operands
        alternate with lists of integer keys, and the last argument is the
        list of output keys.  The result is then reshaped, expanded to the
        parent's variable dimensions and scaled by a plate multiplier.

        Parameters
        ----------
        index : int
            Position of the target parent in ``self.parents``.
        u_parent : optional
            Not used in this method.

        Returns
        -------
        list
            The two einsum-based messages.  When ``self.gaussian_gamma`` is
            true, two more entries computed with ``self._compute_message``
            are appended.

        Raises
        ------
        NotImplementedError
            If ``self.is_constant[index]`` is true.
        ValueError
            If ``index`` is not smaller than the number of parents.
        """

        # NOTE(review): is_constant[index] is read before the range check
        # below, so an out-of-range index may already fail on this lookup --
        # confirm whether the order is intended.
        if self.is_constant[index]:
            raise NotImplementedError(
                "Message to DeltaMoments parent not yet implemented."
            )

        # Check index
        if index >= len(self.parents):
            raise ValueError("Parent index larger than the number of parents")

        # Get messages from other parents and children
        u_parents = self._message_from_parents(exclude=index)
        m = self._message_from_children()
        mask = self.mask

        # Normally we don't need to care about masks when computing the
        # message. However, in this node we want to avoid computing huge message
        # arrays so we sum some axes already here. Thus, we need to apply the
        # mask.
        #
        # Actually, we don't need to care about masks because the message from
        # children has already been masked.

        parent = self.parents[index]

        #
        # Compute the first two messages (one per moment)
        #

        msg = [None, None]

        # Compute the two messages
        for ind in range(2):

            # The total number of keys for the non-plate dimensions
            N = (ind+1) * self.N_keys

            # Plate keys are numbered downwards from N + (number of plates)
            # to N + 1, so the last plate axis always has key N + 1 no matter
            # how many plates an array has.
            # NOTE(review): parent_num_dims is not used below.
            parent_num_dims = len(parent.dims[ind])
            parent_num_plates = len(parent.plates)
            parent_plate_keys = list(range(N + parent_num_plates,
                                           N,
                                           -1))
            parent_dim_keys = self.in_keys[index]
            if ind == 1:
                # Second moment: a shifted copy of the keys (offset by
                # N_keys) precedes the original keys
                parent_dim_keys = ([key + self.N_keys
                                    for key in self.in_keys[index]]
                                   + parent_dim_keys)
            # Interleaved einsum arguments: array, keys, array, keys, ...
            args = []

            # This variable counts the maximum number of plates of the
            # arguments, thus it will tell the number of plates in the result
            # (if the artificially added plates above were ignored).
            result_num_plates = 0
            result_plates = ()

            # Mask and its keys. The mask is not added to the einsum
            # arguments; only its shape contributes to the result plates.
            # NOTE(review): mask_plate_keys is not used below.
            mask_num_plates = np.ndim(mask)
            mask_plates = np.shape(mask)
            mask_plate_keys = list(range(N + mask_num_plates,
                                         N,
                                         -1))
            result_num_plates = max(result_num_plates,
                                    mask_num_plates)
            result_plates = misc.broadcasted_shape(result_plates,
                                                   mask_plates)

            # Moments and keys of other parents
            for (k, u) in enumerate(u_parents):
                if k != index:
                    # Constant parents always provide their first moment, so
                    # the number of variable dimensions does not depend on ind
                    num_dims = (
                        (ind+1) * len(self.in_keys[k])
                        if not self.is_constant[k] else
                        len(self.in_keys[k])
                    )
                    ui = (
                        u[ind] if not self.is_constant[k] else
                        u[0]
                    )
                    num_plates = np.ndim(ui) - num_dims
                    plates = np.shape(ui)[:num_plates]
                    plate_keys = list(range(N + num_plates,
                                            N,
                                            -1))
                    if ind == 0:
                        args.append(ui)
                        args.append(plate_keys + self.in_keys[k])
                    else:
                        in_keys2 = [key + self.N_keys for key in self.in_keys[k]]
                        if not self.is_constant[k]:
                            # Gaussian moments: Use second moment once
                            args.append(ui)
                            args.append(plate_keys + in_keys2 + self.in_keys[k])
                        else:
                            # Delta moments: Use first moment twice
                            args.append(ui)
                            args.append(plate_keys + self.in_keys[k])
                            args.append(ui)
                            args.append(plate_keys + in_keys2)

                    result_num_plates = max(result_num_plates, num_plates)
                    result_plates = misc.broadcasted_shape(result_plates,
                                                           plates)

            # Message and keys from children
            child_num_dims = (ind+1) * len(self.out_keys)
            child_num_plates = np.ndim(m[ind]) - child_num_dims
            child_plates = np.shape(m[ind])[:child_num_plates]
            child_plate_keys = list(range(N + child_num_plates,
                                          N,
                                          -1))
            child_dim_keys = self.out_keys
            if ind == 1:
                child_dim_keys = ([key + self.N_keys
                                   for key in self.out_keys]
                                  + child_dim_keys)
            args.append(m[ind])
            args.append(child_plate_keys + child_dim_keys)

            result_num_plates = max(result_num_plates, child_num_plates)
            result_plates = misc.broadcasted_shape(result_plates,
                                                   child_plates)

            # Output keys, that is, the keys of the parent[index]
            parent_keys = parent_plate_keys + parent_dim_keys

            # Performance trick: Check which axes can be summed because they
            # have length 1 or are non-existing in parent[index]. Thus, remove
            # keys corresponding to unit length axes in parent[index] so that
            # einsum sums over those axes. After computations, these axes must
            # be added back in order to get the correct shape for the message.
            # Also, remove axes/keys that are in output (parent[index]) but not in
            # any inputs (children and other parents).

            parent_shape = parent.get_shape(ind)
            # Original positions (before any deletion) of the removed axes
            removed_axes = []
            for j in range(len(parent_keys)):
                if parent_shape[j] == 1:
                    # Remove the key (take into account the number of keys that
                    # have already been removed)
                    del parent_keys[j-len(removed_axes)]
                    removed_axes.append(j)
                else:
                    # Remove the key if it doesn't appear in any of the
                    # messages from children or other parents.
                    # (args[1::2] picks the key lists out of the interleaved
                    # argument list.)
                    if not np.any([parent_keys[j-len(removed_axes)] in keys
                                   for keys in args[1::2]]):
                        del parent_keys[j-len(removed_axes)]
                        removed_axes.append(j)

            # The output key list is the last einsum argument
            args.append(parent_keys)

            # THE BEEF: Compute the message
            msg[ind] = np.einsum(*args)

            # Find the correct shape for the message array
            message_shape = list(np.shape(msg[ind]))
            # First, add back the axes with length 1. The positions are in
            # increasing order, so each insertion lands at its original index.
            for ax in removed_axes:
                message_shape.insert(ax, 1)
            # Second, remove leading axes for plates that were not present in
            # the child nor other parents' messages. This is not really
            # necessary, but it is just elegant to remove the leading unit
            # length axes that we added artificially at the beginning just
            # because we wanted the key mapping to be simple.
            if parent_num_plates > result_num_plates:
                del message_shape[:(parent_num_plates-result_num_plates)]
            # Then, the actual reshaping
            msg[ind] = np.reshape(msg[ind], message_shape)

            # Broadcasting is not supported for variable dimensions, thus force
            # explicit correct shape for variable dimensions
            var_dims = parent.dims[ind]
            msg[ind] = msg[ind] * np.ones(var_dims)

            # Apply plate multiplier: If this node has non-unit plates that are
            # unit plates in the parent, those plates are summed. However, if
            # the message has unit axis for that plate, it should be first
            # broadcasted to the plates of this node and then summed to the
            # plates of the parent. In order to avoid this broadcasting and
            # summing, it is more efficient to just multiply by the correct
            # factor.
            r = self.broadcasting_multiplier(self.plates,
                                             result_plates,
                                             parent.plates)
            if r != 1:
                msg[ind] *= r

        if self.gaussian_gamma:
            # Third moment of each other parent; 1.0 stands in for constant
            # parents.
            # NOTE(review): this pairs u_parents with self.is_constant by
            # position and skips position `index`; it presumably relies on
            # _message_from_parents(exclude=index) keeping a placeholder at
            # that position -- confirm.
            alphas = [
                (u_parents[i][2] if not is_const else 1.0)
                for (i, is_const) in zip(range(len(u_parents)), self.is_constant)
                if i != index
            ]
            m2 = self._compute_message(m[2], mask, *alphas,
                                       ndim=0,
                                       plates_from=self.plates,
                                       plates_to=parent.plates)
            m3 = self._compute_message(m[3], mask,
                                       ndim=0,
                                       plates_from=self.plates,
                                       plates_to=parent.plates)

            msg = msg + [m2, m3]

        return msg
Exemple #17
0
    def _message_to_parent(self, index):
        """
        Compute the message and mask to a parent node.
        """

        # Check index
        if index >= len(self.parents):
            raise ValueError("Parent index larger than the number of parents")

        # Get messages from other parents and children
        u_parents = self._message_from_parents(exclude=index)
        m = self._message_from_children()
        mask = self.mask

        # Normally we don't need to care about masks when computing the
        # message. However, in this node we want to avoid computing huge message
        # arrays so we sum some axes already here. Thus, we need to apply the
        # mask.

        parent = self.parents[index]

        #
        # Compute the first message
        #

        msg = [None, None]
        
        # Compute the two messages
        for ind in range(2):

            # The total number of keys for the non-plate dimensions
            N = (ind+1) * self.N_keys

            # Add an array of ones to ensure proper shape and number of
            # plates. Note that this adds an axis for each plate. At the end, we
            # want to remove axes that were created only because of this
            parent_num_dims = len(parent.dims[ind])
            parent_num_plates = len(parent.plates)
            parent_plate_keys = list(range(N + parent_num_plates,
                                           N,
                                           -1))
            parent_dim_keys = self.in_keys[index]
            if ind == 1:
                parent_dim_keys = ([key + self.N_keys
                                    for key in self.in_keys[index]]
                                   + parent_dim_keys)
            args = []
            args.append(np.ones((1,)*parent_num_plates + parent.dims[ind]))
            args.append(parent_plate_keys + parent_dim_keys)

            # This variable counts the maximum number of plates of the
            # arguments, thus it will tell the number of plates in the result
            # (if the artificially added plates above were ignored).
            result_num_plates = 0
            result_plates = ()

            # Mask and its keysr
            mask_num_plates = np.ndim(mask)
            mask_plates = np.shape(mask)
            mask_plate_keys = list(range(N + mask_num_plates, 
                                         N,
                                         -1))
            result_num_plates = max(result_num_plates,
                                    mask_num_plates)
            result_plates = misc.broadcasted_shape(result_plates,
                                                   mask_plates)
            args.append(mask)
            args.append(mask_plate_keys)

            # Moments and keys of other parents
            for (k, u) in enumerate(u_parents):
                if k != index:
                    num_dims = (ind+1) * len(self.in_keys[k])
                    num_plates = np.ndim(u[ind]) - num_dims
                    plates = np.shape(u[ind])[:num_plates]
                    plate_keys = list(range(N + num_plates, 
                                            N,
                                            -1))
                    dim_keys = self.in_keys[k]
                    if ind == 1:
                        dim_keys = ([key + self.N_keys 
                                     for key in self.in_keys[k]]
                                    + dim_keys)
                    args.append(u[ind])
                    args.append(plate_keys + dim_keys)

                    result_num_plates = max(result_num_plates, num_plates)
                    result_plates = misc.broadcasted_shape(result_plates,
                                                           plates)

            # Message and keys from children
            child_num_dims = (ind+1) * len(self.out_keys)
            child_num_plates = np.ndim(m[ind]) - child_num_dims
            child_plates = np.shape(m[ind])[:child_num_plates]
            child_plate_keys = list(range(N + child_num_plates,
                                          N,
                                          -1))
            child_dim_keys = self.out_keys
            if ind == 1:
                child_dim_keys = ([key + self.N_keys
                                   for key in self.out_keys]
                                  + child_dim_keys)
            args.append(m[ind])
            args.append(child_plate_keys + child_dim_keys)

            result_num_plates = max(result_num_plates, child_num_plates)
            result_plates = misc.broadcasted_shape(result_plates,
                                                   child_plates)

            # Output keys, that is, the keys of the parent[index]
            parent_keys = parent_plate_keys + parent_dim_keys

            # Performance trick: Check which axes can be summed because they
            # have length 1 or are non-existing in parent[index]. Thus, remove
            # keys corresponding to unit length axes in parent[index] so that
            # einsum sums over those axes. After computations, these axes must
            # be added back in order to get the correct shape for the message.

            parent_shape = parent.get_shape(ind)
            removed_axes = []
            for j in range(len(parent_keys)):
                if parent_shape[j] == 1:
                    # Remove the key (take into account the number of keys that
                    # have already been removed)
                    del parent_keys[j-len(removed_axes)]
                    removed_axes.append(j)

            args.append(parent_keys)

            # THE BEEF: Compute the message
            msg[ind] = np.einsum(*args)

            # Find the correct shape for the message array
            message_shape = list(np.shape(msg[ind]))
            # First, add back the axes with length 1
            for ax in removed_axes:
                message_shape.insert(ax, 1)
            # Second, remove leading axes for plates that were not present in
            # the child nor other parents' messages. This is not really
            # necessary, but it is just elegant to remove the leading unit
            # length axes that we added artificially at the beginning just
            # because we wanted the key mapping to be simple.
            if parent_num_plates > result_num_plates:
                del message_shape[:(parent_num_plates-result_num_plates)]
            # Then, the actual reshaping
            msg[ind] = np.reshape(msg[ind], message_shape)

            # Apply plate multiplier: If this node has non-unit plates that are
            # unit plates in the parent, those plates are summed. However, if
            # the message has unit axis for that plate, it should be first
            # broadcasted to the plates of this node and then summed to the
            # plates of the parent. In order to avoid this broadcasting and
            # summing, it is more efficient to just multiply by the correct
            # factor.
            r = self.broadcasting_multiplier(self.plates, 
                                             result_plates,
                                             parent.plates)
            if r != 1:
                msg[ind] *= r

        if self.gaussian_gamma:
            alphas = [u_parents[i][2] 
                      for i in range(len(u_parents)) if i != index]
            ## logalphas = [u_parents[i][3]
            ##              for i in range(len(u_parents)) if i != index]
            m2 = self._compute_message(mask, m[2], *alphas,
                                       ndim=0,
                                       plates_from=self.plates,
                                       plates_to=parent.plates)
            m3 = self._compute_message(mask, m[3],#*logalphas,
                                       ndim=0,
                                       plates_from=self.plates,
                                       plates_to=parent.plates)

            msg = msg + [m2, m3]
            
        return msg
Exemple #18
0
    def _compute_message(*arrays, plates_from=(), plates_to=(), ndim=0):
        """
        A general function for computing messages by sum-multiply

        The function computes the product of the input arrays and then sums to
        the requested plates.

        Parameters
        ----------
        *arrays : array_like
            Factors of the product.  Their shapes are broadcast against each
            other; the trailing `ndim` axes of the broadcast shape are treated
            as variable dimensions, the remaining leading axes as plates.
        plates_from : tuple of int
            Plate shape the product is summed from.
        plates_to : tuple of int
            Plate shape the product is summed to.  Must satisfy
            ``misc.is_shape_subset(plates_to, plates_from)``.
        ndim : int
            Number of trailing axes that are never summed.

        Returns
        -------
        ndarray
            The summed product, reshaped to the trailing plates (element-wise
            minimum of `plates_to` and the plates present in the arrays)
            followed by the variable dimensions.

        Raises
        ------
        ValueError
            If `plates_to` is not a shape subset of `plates_from`.
        """

        # Check that the plates broadcast properly
        if not misc.is_shape_subset(plates_to, plates_from):
            raise ValueError("plates_to must be broadcastable to plates_from")

        # Compute the explicit shape of the product
        shapes = [np.shape(array) for array in arrays]
        arrays_shape = misc.broadcasted_shape(*shapes)

        # Compute plates and dims that are present
        if ndim == 0:
            arrays_plates = arrays_shape
            dims = ()
        else:
            arrays_plates = arrays_shape[:-ndim]
            dims = arrays_shape[-ndim:]

        # Compute the correction term.  If some of the plates that should be
        # summed are actually broadcasted, one must multiply by the size of the
        # corresponding plate
        r = Node.broadcasting_multiplier(plates_from, arrays_plates, plates_to)

        # For simplicity, make the arrays equal ndim
        # NOTE(review): the key layout below assumes make_equal_ndim pads with
        # leading unit axes -- confirm against misc.
        arrays = misc.make_equal_ndim(*arrays)

        # Keys for the input plates: (N-1, N-2, ..., 0).  The keys must be
        # descending so that key k labels the (k+1)-th plate axis counted from
        # the END, which is exactly how plates_to is indexed below.  Ascending
        # keys would pair each axis with the wrong entry of plates_to whenever
        # the plate shape is not symmetric.
        nplates = len(arrays_plates)
        in_plate_keys = list(range(nplates - 1, -1, -1))

        # Keys for the output plates: keep a plate only if the target has that
        # axis and it is not a unit axis there
        out_plate_keys = [key
                          for key in in_plate_keys
                          if key < len(plates_to) and plates_to[-key-1] != 1]

        # Keys for the dims
        dim_keys = list(range(nplates, nplates+ndim))

        # Total input and output keys
        in_keys = len(arrays) * [in_plate_keys + dim_keys]
        out_keys = out_plate_keys + dim_keys

        # Compute the sum-product with correction
        einsum_args = misc.zipper_merge(arrays, in_keys) + [out_keys]
        y = r * np.einsum(*einsum_args)

        # Reshape the result so that summed plates which exist in plates_to
        # come back as unit axes.  (An empty range yields an empty list, so no
        # special case is needed when there are no plates in the result.)
        nplates_result = min(len(plates_to), len(arrays_plates))
        plates_result = [min(plates_to[ind], arrays_plates[ind])
                         for ind in range(-nplates_result, 0)]
        y = np.reshape(y, plates_result + list(dims))

        return y
Exemple #19
0
    def _message_to_parent(self, index):
        """
        Compute the message to parent `index`, masked and summed to its shape.

        The raw message and mask come from
        ``self._get_message_and_mask_to_parent(index)``.  If the raw message is
        a callable (a log-pdf function used for black-box variational
        inference), a wrapper function is returned instead of a list.

        Parameters
        ----------
        index : int
            Position of the target parent in ``self.parents``.

        Returns
        -------
        list or callable
            For a list message: the same list with every non-None entry
            replaced by its masked, plate-summed and multiplier-scaled version
            (None entries are left as they are).  For a callable message: a
            function that evaluates it and applies the mask and plate
            summation to the result.

        Raises
        ------
        ValueError
            If `index` is not smaller than the number of parents, or if the
            plate multipliers of this node and the parent are incompatible.
        """

        # Compute the message, check plates, apply mask and sum over some plates
        if index >= len(self.parents):
            raise ValueError("Parent index larger than the number of parents")

        # Compute the message and mask
        (m, mask) = self._get_message_and_mask_to_parent(index)
        mask = misc.squeeze(mask)

        # Plates in the mask
        plates_mask = np.shape(mask)

        # The parent we're sending the message to
        parent = self.parents[index]

        # Plates with respect to the parent
        plates_self = self._plates_to_parent(index)

        # Plate multiplier of the parent
        multiplier_parent = self._plates_multiplier_from_parent(index)

        # Check if m is a logpdf function (for black-box variational inference)
        if callable(m):
            def m_function(*args):
                lpdf = m(*args)
                # Log pdf only contains plate axes!
                plates_m = np.shape(lpdf)
                r = (self.broadcasting_multiplier(plates_self,
                                                  plates_m,
                                                  plates_mask,
                                                  parent.plates) *
                     self.broadcasting_multiplier(self.plates_multiplier,
                                                  multiplier_parent))
                axes_msg = misc.axes_to_collapse(plates_m, parent.plates)
                # The previous code indexed `m[i]` and used `mask_i` and
                # `shape_parent`, none of which exist when this closure runs,
                # and it returned nothing.  Because the log pdf has no
                # variable axes, the plain mask and the parent's plates are
                # the matching quantities.
                # NOTE(review): confirm that the mask needs no pre-summation
                # when it has plates that the log pdf lacks.
                lpdf = misc.sum_multiply(mask, lpdf, r,
                                         axis=axes_msg,
                                         keepdims=True)

                # Remove leading singular plates if the parent does not have
                # those plate axes.
                return misc.squeeze_to_dim(lpdf, len(parent.plates))

            return m_function

        # Compact the message to a proper shape
        for i in range(len(m)):

            # Empty messages are given as None. We can ignore those.
            if m[i] is not None:

                try:
                    r = self.broadcasting_multiplier(self.plates_multiplier,
                                                     multiplier_parent)
                except Exception as err:
                    # Catch Exception rather than everything so that
                    # KeyboardInterrupt/SystemExit propagate untouched; keep
                    # the original error as the cause.
                    raise ValueError("The plate multipliers are incompatible. "
                                     "This node (%s) has %s and parent[%d] "
                                     "(%s) has %s"
                                     % (self.name,
                                        self.plates_multiplier,
                                        index,
                                        parent.name,
                                        multiplier_parent)) from err

                ndim = len(parent.dims[i])
                # Source and target shapes
                if ndim > 0:
                    dims = misc.broadcasted_shape(np.shape(m[i])[-ndim:],
                                                  parent.dims[i])
                    from_shape = plates_self + dims
                else:
                    from_shape = plates_self
                to_shape = parent.get_shape(i)
                # Add variable axes to the mask
                mask_i = misc.add_trailing_axes(mask, ndim)
                # Apply mask and sum plate axes as necessary (and apply plate
                # multiplier).  The variable axes are already part of the
                # from/to shapes, hence ndim=0.
                m[i] = r * misc.sum_multiply_to_plates(m[i], mask_i,
                                                       to_plates=to_shape,
                                                       from_plates=from_shape,
                                                       ndim=0)

        return m
Exemple #20
0
    def _message_to_parent(self, index):
        """
        Build the message sent to ``self.parents[index]``.

        The raw message and mask are obtained from
        ``self._get_message_and_mask_to_parent(index)``.  A list message is
        masked, summed to the parent's shape and scaled by the plate
        multiplier, entry by entry.  A callable message (log-pdf function for
        black-box variational inference) is wrapped in a function that does
        the masking and summation on its result.

        Parameters
        ----------
        index : int
            Index of the parent in ``self.parents``.

        Returns
        -------
        list or callable
            The processed message list (None entries are kept as None), or
            the wrapper function when the raw message is callable.

        Raises
        ------
        ValueError
            If `index` is out of range, or if the plate multipliers of this
            node and the parent cannot be combined.
        """

        # Compute the message, check plates, apply mask and sum over some plates
        if index >= len(self.parents):
            raise ValueError("Parent index larger than the number of parents")

        # Compute the message and mask
        (m, mask) = self._get_message_and_mask_to_parent(index)
        mask = misc.squeeze(mask)

        # Plates in the mask
        plates_mask = np.shape(mask)

        # The parent we're sending the message to
        parent = self.parents[index]

        # Plates with respect to the parent
        plates_self = self._plates_to_parent(index)

        # Plate multiplier of the parent
        multiplier_parent = self._plates_multiplier_from_parent(index)

        # Check if m is a logpdf function (for black-box variational inference)
        if callable(m):

            def m_function(*args):
                lpdf = m(*args)
                # Log pdf only contains plate axes!
                plates_m = np.shape(lpdf)
                r = (self.broadcasting_multiplier(plates_self, plates_m,
                                                  plates_mask, parent.plates) *
                     self.broadcasting_multiplier(self.plates_multiplier,
                                                  multiplier_parent))
                axes_msg = misc.axes_to_collapse(plates_m, parent.plates)
                # Operate on the evaluated log pdf with the plain mask.  The
                # earlier body referred to `m[i]`, `mask_i` and `shape_parent`,
                # which are never bound on this code path, and produced no
                # return value.
                # NOTE(review): verify the mask does not have to be summed
                # first when it carries plates missing from the log pdf.
                masked = misc.sum_multiply(mask,
                                           lpdf,
                                           r,
                                           axis=axes_msg,
                                           keepdims=True)

                # Remove leading singular plates if the parent does not have
                # those plate axes.  With no variable axes, the parent's shape
                # is just its plates.
                return misc.squeeze_to_dim(masked, len(parent.plates))

            return m_function

        # Compact the message to a proper shape
        for i in range(len(m)):

            # Empty messages are given as None. We can ignore those.
            if m[i] is not None:

                try:
                    r = self.broadcasting_multiplier(self.plates_multiplier,
                                                     multiplier_parent)
                except Exception as err:
                    # Narrowed from a bare except so interpreter-exit
                    # exceptions are not converted; the cause is chained.
                    raise ValueError("The plate multipliers are incompatible. "
                                     "This node (%s) has %s and parent[%d] "
                                     "(%s) has %s" %
                                     (self.name, self.plates_multiplier, index,
                                      parent.name, multiplier_parent)) from err

                ndim = len(parent.dims[i])
                # Source and target shapes
                if ndim > 0:
                    dims = misc.broadcasted_shape(
                        np.shape(m[i])[-ndim:], parent.dims[i])
                    from_shape = plates_self + dims
                else:
                    from_shape = plates_self
                to_shape = parent.get_shape(i)
                # Add variable axes to the mask
                mask_i = misc.add_trailing_axes(mask, ndim)
                # Apply mask and sum plate axes as necessary (and apply plate
                # multiplier).  ndim=0 because the variable axes are already
                # included in from_shape/to_shape.
                m[i] = r * misc.sum_multiply_to_plates(m[i],
                                                       mask_i,
                                                       to_plates=to_shape,
                                                       from_plates=from_shape,
                                                       ndim=0)

        return m
Exemple #21
0
    def _compute_message(*arrays, plates_from=(), plates_to=(), ndim=0):
        """
        A general function for computing messages by sum-multiply

        The function computes the product of the input arrays and then sums to
        the requested plates.

        Parameters
        ----------
        *arrays : array_like
            Arrays to multiply together.  The last `ndim` axes of their
            broadcast shape are variable dimensions; all preceding axes are
            plates.
        plates_from : tuple of int
            Plate shape before summation.
        plates_to : tuple of int
            Plate shape after summation; must be a shape subset of
            `plates_from` according to ``misc.is_shape_subset``.
        ndim : int
            Number of trailing variable axes, which are always kept.

        Returns
        -------
        ndarray
            Product of the arrays summed over the plates that are absent or
            of unit length in `plates_to`, with shape
            ``plates_result + dims``.

        Raises
        ------
        ValueError
            If `plates_to` is not broadcastable to `plates_from`.
        """

        # Check that the plates broadcast properly
        if not misc.is_shape_subset(plates_to, plates_from):
            raise ValueError("plates_to must be broadcastable to plates_from")

        # Compute the explicit shape of the product
        shapes = [np.shape(array) for array in arrays]
        arrays_shape = misc.broadcasted_shape(*shapes)

        # Compute plates and dims that are present
        if ndim == 0:
            arrays_plates = arrays_shape
            dims = ()
        else:
            arrays_plates = arrays_shape[:-ndim]
            dims = arrays_shape[-ndim:]

        # Compute the correction term.  If some of the plates that should be
        # summed are actually broadcasted, one must multiply by the size of the
        # corresponding plate
        r = Node.broadcasting_multiplier(plates_from, arrays_plates, plates_to)

        # For simplicity, make the arrays equal ndim
        # NOTE(review): assumes the padding is done with leading unit axes, so
        # that the last nplates + ndim axes line up -- confirm in misc.
        arrays = misc.make_equal_ndim(*arrays)

        # Keys for the input plates: (N-1, N-2, ..., 0).  The first plate axis
        # gets the largest key and the last plate axis gets key 0, matching
        # the from-the-end lookup plates_to[-key - 1] used below.  Numbering
        # the axes in ascending order would make that lookup inspect the
        # mirrored axis and sum over the wrong plates.
        nplates = len(arrays_plates)
        in_plate_keys = list(range(nplates - 1, -1, -1))

        # Keys for the output plates
        out_plate_keys = [
            key for key in in_plate_keys
            if key < len(plates_to) and plates_to[-key - 1] != 1
        ]

        # Keys for the dims
        dim_keys = list(range(nplates, nplates + ndim))

        # Total input and output keys
        in_keys = len(arrays) * [in_plate_keys + dim_keys]
        out_keys = out_plate_keys + dim_keys

        # Compute the sum-product with correction
        einsum_args = misc.zipper_merge(arrays, in_keys) + [out_keys]
        y = r * np.einsum(*einsum_args)

        # Reshape the result: plates that were summed but exist in plates_to
        # are restored as unit axes.  The comprehension is empty when there
        # are no result plates, so that case needs no separate branch.
        nplates_result = min(len(plates_to), len(arrays_plates))
        plates_result = [
            min(plates_to[ind], arrays_plates[ind])
            for ind in range(-nplates_result, 0)
        ]
        y = np.reshape(y, plates_result + list(dims))

        return y
Exemple #22
0
    def _message_to_parent(self, index):
        """
        Return the message to ``self.parents[index]`` in the parent's shape.

        Each non-None entry of the raw message is multiplied by the mask and
        by a plate multiplier, summed over the axes the parent does not have,
        and stripped of surplus leading axes.  None entries are passed
        through unchanged.  The list returned by
        ``self._get_message_and_mask_to_parent`` is modified in place and
        returned.

        Raises ValueError when `index` is not a valid parent index or when the
        plates of the message, the mask and the parent are not compatible
        with the plates of this node.
        """

        if index >= len(self.parents):
            raise ValueError("Parent index larger than the number of parents")

        # Raw message, and the mask with its unit axes squeezed away
        (m, mask) = self._get_message_and_mask_to_parent(index)
        mask = misc.squeeze(mask)
        plates_mask = np.shape(mask)

        # Receiver of the message
        parent = self.parents[index]

        for i in range(len(m)):

            # None marks an empty message; nothing to do for it
            if m[i] is None:
                continue

            # Split the message shape into plates and variable axes
            shape_m = np.shape(m[i])
            n_dims = len(parent.dims[i])
            plates_m = shape_m[:-n_dims] if n_dims > 0 else shape_m

            # A plate that is singular in the message, the mask and the
            # parent is never actually summed over, so its size has to enter
            # as an explicit factor.
            plates_self = self._plates_to_parent(index)
            try:
                r = self._plate_multiplier(plates_self,
                                           plates_m,
                                           plates_mask,
                                           parent.plates)
            except ValueError:
                details = (index,
                           parent.name,
                           self.name,
                           shape_m,
                           plates_m,
                           plates_mask,
                           plates_self,
                           index,
                           parent.plates)
                raise ValueError("The plates of the message, the mask and "
                                 "parent[%d] node (%s) are not a "
                                 "broadcastable subset of the plates of "
                                 "this node (%s).  The message has shape "
                                 "%s, meaning plates %s. The mask has "
                                 "plates %s. This node has plates %s with "
                                 "respect to the parent[%d], which has "
                                 "plates %s."
                                 % details)

            # Give the mask trailing unit axes for the variable dimensions
            shape_mask = plates_mask + (1,) * n_dims
            mask_i = np.reshape(mask, shape_mask)

            # Collapse mask axes that neither the message nor the parent has
            shape_parent = parent.get_shape(i)
            shape_msg = misc.broadcasted_shape(shape_m, shape_parent)
            mask_i = np.sum(mask_i,
                            axis=misc.axes_to_collapse(shape_mask, shape_msg),
                            keepdims=True)

            # Masked message, summed over the plates missing from the parent
            m[i] = misc.sum_multiply(
                mask_i, m[i], r,
                axis=misc.axes_to_collapse(shape_msg, shape_parent),
                keepdims=True)

            # Drop leading singular plates the parent does not have
            m[i] = misc.squeeze_to_dim(m[i], len(shape_parent))

        return m
Exemple #23
0
        def sum_plates(V, *plates):
            """
            Sum all elements of V and scale by a broadcasting multiplier.

            The multiplier is obtained from
            ``self.node_X.broadcasting_multiplier`` for the broadcast of the
            given plate shapes against the shape of V.
            """
            target_plates = misc.broadcasted_shape(*plates)
            multiplier = self.node_X.broadcasting_multiplier(target_plates,
                                                             np.shape(V))
            return multiplier * np.sum(V)