Example #1
0
    def _check_shape(self, u, broadcast=True):

        if len(u) != len(self.dims):
            raise ValueError("Incorrect number of arrays")

        for (dimsi, ui) in zip(self.dims, u):
            sh_true = self.plates + dimsi
            sh = np.shape(ui)
            ndim = len(dimsi)
            errmsg = (
                "Shape of the given array not equal to the shape of the node.\n"
                "Received shape: {0}\n"
                "Expected shape: {1}\n"
                "Check plates."
                .format(sh, sh_true)
            )
            if not broadcast:
                if sh != sh_true:
                    raise ValueError(errmsg)
            else:
                if ndim == 0:
                    if not misc.is_shape_subset(sh, sh_true):
                        raise ValueError(errmsg)
                else:
                    plates_ok = misc.is_shape_subset(sh[:-ndim], self.plates)
                    dims_ok = (sh[-ndim:] == dimsi)
                    if not (plates_ok and dims_ok):
                        raise ValueError(errmsg)

        return
Example #2
0
 def regularization(self, regularization):
     """
     Store the regularization pair after validating it.

     Parameters
     ----------
     regularization : sequence of length 2
         Entry ``i`` must have a shape that passes
         ``misc.is_shape_subset`` against ``self.get_shape(i)``.

     Raises
     ------
     ValueError
         If `regularization` does not have exactly two entries or if an
         entry has an incompatible shape.
     """
     if len(regularization) != 2:
         raise ValueError("Regularization must be a 2-tuple")
     for (ind, reg) in enumerate(regularization):
         shape = np.shape(reg)
         expected = self.get_shape(ind)
         if not misc.is_shape_subset(shape, expected):
             raise ValueError(
                 "Wrong shape: regularization[%d] has shape %s which is "
                 "not broadcastable to %s" % (ind, shape, expected)
             )
     self.__regularization = regularization
     return
Example #3
0
    def _update_phi_from_parents(self, *u_parents):
        """
        Recompute ``self.phi`` from the given parent messages.

        The values are obtained from
        ``self._distribution.compute_phi_from_parents`` and stored as a list.
        Each entry is then given exactly ``len(self.plates) + self.ndims[i]``
        axes, by adding leading axes or by reshaping away extra leading ones.

        Parameters
        ----------
        *u_parents
            Forwarded unchanged to ``compute_phi_from_parents``.

        Raises
        ------
        ValueError
            If an entry of phi does not pass ``misc.is_shape_subset`` against
            ``self.get_shape(i)`` after the adjustment.
        """

        # TODO/FIXME: Could this be combined to the function
        # _update_distribution_and_lowerbound ?
        # No, because some initialization methods may want to use this.

        # This makes correct broadcasting
        self.phi = list(
            self._distribution.compute_phi_from_parents(*u_parents)
        )

        # Make sure phi has the correct number of axes. It makes life
        # a bit easier elsewhere.
        nplates = len(self.plates)
        for (i, phi_i) in enumerate(self.phi):
            target_ndim = nplates + self.ndims[i]
            extra = np.ndim(phi_i) - target_ndim
            if extra < 0:
                # Add axes
                phi_i = misc.add_leading_axes(phi_i, -extra)
            elif extra > 0:
                # Remove extra leading axes.  Slice with a non-negative start
                # index: a negative index computed as -target_ndim would be
                # -0 == 0 when target_ndim is zero and keep every axis.
                phi_i = np.reshape(phi_i, np.shape(phi_i)[extra:])
            self.phi[i] = phi_i

            # Check that the shape is correct
            if not misc.is_shape_subset(np.shape(phi_i), self.get_shape(i)):
                raise ValueError("Incorrect shape of phi[%d] in node class %s. "
                                 "Shape is %s but it should be broadcastable "
                                 "to shape %s."
                                 % (i,
                                    self.__class__.__name__,
                                    np.shape(phi_i),
                                    self.get_shape(i)))
Example #4
0
    def _mask_to_parent(self, index):
        """
        Return this node's mask as seen by parent[index].

        The mask marks which plate connections are active.  It is obtained
        from ``_compute_mask_to_parent``, validated against the plates of this
        node with respect to the parent, and then reduced with a logical "or"
        over the axes reported by ``misc.axes_to_collapse`` so that it fits
        the parent's plates.  Because plates have been collapsed, the result
        is meant for propagating the mask to parents, not for masking
        messages.

        Raises
        ------
        ValueError
            If the computed mask is not a sub-shape of the plates with respect
            to the parent.
        """
        raw_mask = self._compute_mask_to_parent(index, self.mask)
        raw_shape = np.shape(raw_mask)

        # Validate the mask shape before reducing it
        expected_plates = self._plates_to_parent(index)
        if not misc.is_shape_subset(raw_shape, expected_plates):
            template = (
                "In node %s, the mask being sent to "
                "parent[%d] (%s) has invalid shape: The shape of "
                "the mask %s is not a sub-shape of the plates of "
                "the node with respect to the parent %s. It could "
                "be that this node (%s) is manipulating plates "
                "but has not overwritten the method "
                "_compute_mask_to_parent."
            )
            details = (
                self.name,
                index,
                self.parents[index].name,
                raw_shape,
                expected_plates,
                self.__class__.__name__,
            )
            raise ValueError(template % details)

        # Logical "or" over the axes that must be collapsed for the parent,
        # then drop leading axes down to the parent's number of plates.
        target_plates = self.parents[index].plates
        collapse_axes = misc.axes_to_collapse(raw_shape, target_plates)
        reduced = np.any(raw_mask, axis=collapse_axes, keepdims=True)
        return misc.squeeze_to_dim(reduced, len(target_plates))
Example #5
0
    def _mask_to_parent(self, index):
        """
        Compute the mask that is propagated to parent[index].

        The mask from ``_compute_mask_to_parent`` tells which plate
        connections are active.  After a shape check it is "summed" (logical
        or) over the axes given by ``misc.axes_to_collapse`` and squeezed to
        the number of plates of the parent.  Since some plates are already
        collapsed, the result must not be used for masking messages.

        Raises
        ------
        ValueError
            If the mask shape is not a sub-shape of the plates of this node
            with respect to the parent.
        """
        computed = self._compute_mask_to_parent(index, self.mask)
        computed_shape = np.shape(computed)
        plates_wrt_parent = self._plates_to_parent(index)

        shape_ok = misc.is_shape_subset(computed_shape, plates_wrt_parent)
        if not shape_ok:
            message = ("In node %s, the mask being sent to "
                       "parent[%d] (%s) has invalid shape: The shape of "
                       "the mask %s is not a sub-shape of the plates of "
                       "the node with respect to the parent %s. It could "
                       "be that this node (%s) is manipulating plates "
                       "but has not overwritten the method "
                       "_compute_mask_to_parent."
                       % (self.name,
                          index,
                          self.parents[index].name,
                          computed_shape,
                          plates_wrt_parent,
                          self.__class__.__name__))
            raise ValueError(message)

        # Collapse (logical or) the axes the parent does not carry
        plates_of_parent = self.parents[index].plates
        axes = misc.axes_to_collapse(computed_shape, plates_of_parent)
        collapsed = np.any(computed, axis=axes, keepdims=True)
        return misc.squeeze_to_dim(collapsed, len(plates_of_parent))
Example #6
0
    def _update_phi_from_parents(self, *u_parents):
        """
        Recompute ``self.phi`` from the given parent messages.

        The values come from
        ``self._distribution.compute_phi_from_parents`` and are stored as a
        list.  Every entry is adjusted to have exactly
        ``len(self.plates) + self.ndims[i]`` axes: missing leading axes are
        added and surplus leading axes are reshaped away.

        Parameters
        ----------
        *u_parents
            Forwarded unchanged to ``compute_phi_from_parents``.

        Raises
        ------
        ValueError
            If an adjusted entry does not pass ``misc.is_shape_subset``
            against ``self.get_shape(i)``.
        """

        # TODO/FIXME: Could this be combined to the function
        # _update_distribution_and_lowerbound ?
        # No, because some initialization methods may want to use this.

        # This makes correct broadcasting
        self.phi = list(
            self._distribution.compute_phi_from_parents(*u_parents)
        )

        # Make sure phi has the correct number of axes. It makes life
        # a bit easier elsewhere.
        plate_ndim = len(self.plates)
        for (i, phi_i) in enumerate(self.phi):
            wanted_ndim = plate_ndim + self.ndims[i]
            surplus = np.ndim(phi_i) - wanted_ndim
            if surplus < 0:
                # Add axes
                phi_i = misc.add_leading_axes(phi_i, -surplus)
            elif surplus > 0:
                # Remove extra leading axes.  The slice starts at a
                # non-negative offset on purpose: slicing with -wanted_ndim
                # breaks when wanted_ndim is zero, because -0 == 0 selects
                # the whole shape instead of an empty one.
                phi_i = np.reshape(phi_i, np.shape(phi_i)[surplus:])
            self.phi[i] = phi_i

            # Check that the shape is correct
            if not misc.is_shape_subset(np.shape(phi_i), self.get_shape(i)):
                raise ValueError(
                    "Incorrect shape of phi[%d] in node class %s. "
                    "Shape is %s but it should be broadcastable "
                    "to shape %s." %
                    (i, self.__class__.__name__, np.shape(phi_i),
                     self.get_shape(i)))
Example #7
0
    def _message_to_child(self):

        u = self.get_moments()
        
        # Debug: Check that the message has appropriate shape
        for (ui, dim) in zip(u, self.dims):
            ndim = len(dim)
            if ndim > 0:
                if np.shape(ui)[-ndim:] != dim:
                    raise RuntimeError(
                        "A bug found by _message_to_child for %s: "
                        "The variable axes of the moments %s are not equal to "
                        "the axes %s defined by the node %s. A possible reason "
                        "is that the plates of the node are inferred "
                        "incorrectly from the parents, and the method "
                        "_plates_from_parents should be implemented."
                        % (self.__class__.__name__,
                           np.shape(ui)[-ndim:],
                           dim,
                           self.name))
                if not misc.is_shape_subset(np.shape(ui)[:-ndim],
                                            self.plates):
                    raise RuntimeError(
                        "A bug found by _message_to_child for %s: "
                        "The plate axes of the moments %s are not a subset of "
                        "the plate axes %s defined by the node %s."
                        % (self.__class__.__name__,
                           np.shape(ui)[:-ndim],
                           self.plates,
                           self.name))
            else:
                if not misc.is_shape_subset(np.shape(ui), self.plates):
                    raise RuntimeError(
                        "A bug found by _message_to_child for %s: "
                        "The plate axes of the moments %s are not a subset of "
                        "the plate axes %s defined by the node %s."
                        % (self.__class__.__name__,
                           np.shape(ui),
                           self.plates,
                           self.name))
        return u
Example #8
0
    def _message_to_child(self):

        u = self.get_moments()
        
        # Debug: Check that the message has appropriate shape
        for (ui, dim) in zip(u, self.dims):
            ndim = len(dim)
            if ndim > 0:
                if np.shape(ui)[-ndim:] != dim:
                    raise RuntimeError(
                        "A bug found by _message_to_child for %s: "
                        "The variable axes of the moments %s are not equal to "
                        "the axes %s defined by the node %s. A possible reason "
                        "is that the plates of the node are inferred "
                        "incorrectly from the parents, and the method "
                        "_plates_from_parents should be implemented."
                        % (self.__class__.__name__,
                           np.shape(ui)[-ndim:],
                           dim,
                           self.name))
                if not misc.is_shape_subset(np.shape(ui)[:-ndim],
                                            self.plates):
                    raise RuntimeError(
                        "A bug found by _message_to_child for %s: "
                        "The plate axes of the moments %s are not a subset of "
                        "the plate axes %s defined by the node %s."
                        % (self.__class__.__name__,
                           np.shape(ui)[:-ndim],
                           self.plates,
                           self.name))
            else:
                if not misc.is_shape_subset(np.shape(ui), self.plates):
                    raise RuntimeError(
                        "A bug found by _message_to_child for %s: "
                        "The plate axes of the moments %s are not a subset of "
                        "the plate axes %s defined by the node %s."
                        % (self.__class__.__name__,
                           np.shape(ui),
                           self.plates,
                           self.name))
        return u
Example #9
0
 def set_value(self, x):
     """
     Set the moments of the node from the value `x`.

     `x` is converted with ``np.asanyarray`` and passed to
     ``self._moments.compute_fixed_moments``; the result is stored in
     ``self.u``.  Each moment is then checked against the node shape
     (plates followed by the corresponding dims).

     Raises
     ------
     ValueError
         If a computed moment does not pass ``misc.is_shape_subset``
         against the expected shape.
     """
     self.u = self._moments.compute_fixed_moments(np.asanyarray(x))

     for (i, dimsi) in enumerate(self.dims):
         expected = tuple(self.plates) + tuple(dimsi)
         actual = np.shape(self.u[i])
         if misc.is_shape_subset(actual, expected):
             continue
         template = "Incorrect shape {0} for the array, expected {1}"
         raise ValueError(template.format(actual, expected))
     return
Example #10
0
 def set_value(self, x):
     """
     Fix the node to the value `x` and validate the resulting moments.

     The value is converted to an array, turned into moments by
     ``self._moments.compute_fixed_moments`` and stored in ``self.u``.
     Every moment must be a sub-shape (per ``misc.is_shape_subset``) of
     the plates of the node followed by the matching entry of
     ``self.dims``.

     Raises
     ------
     ValueError
         If a moment has an incompatible shape.
     """
     value = np.asanyarray(x)
     self.u = self._moments.compute_fixed_moments(value)

     for (ind, dims) in enumerate(self.dims):
         wanted_shape = tuple(self.plates) + tuple(dims)
         moment_shape = np.shape(self.u[ind])
         shape_ok = misc.is_shape_subset(moment_shape, wanted_shape)
         if not shape_ok:
             raise ValueError(
                 "Incorrect shape {0} for the array, expected {1}"
                 .format(moment_shape, wanted_shape)
             )
     return
Example #11
0
 def _total_plates(cls, plates, *parent_plates):
     if plates is None:
         # By default, use the minimum number of plates determined
         # from the parent nodes
         try:
             return misc.broadcasted_shape(*parent_plates)
         except ValueError:
             raise ValueError("The plates of the parents do not broadcast.")
     else:
         # Check that the parent_plates are a subset of plates.
         for (ind, p) in enumerate(parent_plates):
             if not misc.is_shape_subset(p, plates):
                 raise ValueError("The plates %s of the parents "
                                  "are not broadcastable to the given "
                                  "plates %s." % (p, plates))
         return plates
Example #12
0
    def _update_mask(self):
        """
        Recompute the mask of this node from its children and propagate it.

        The masks that the children send via ``_mask_to_parent`` are combined
        with a logical "or" (starting from a scalar False), stored with
        ``_set_mask`` and then every parent is asked to update its own mask.

        Raises
        ------
        ValueError
            If the shape of the resulting ``self.mask`` is not a sub-shape of
            ``self.plates``.
        """
        # Logical "or" of the masks coming from all children
        combined = np.array(False)
        for (child, index) in self.children:
            child_mask = child._mask_to_parent(index)
            combined = np.logical_or(combined, child_mask)

        # Store the mask on this node and verify the stored result
        self._set_mask(combined)
        mask_shape = np.shape(self.mask)
        if not misc.is_shape_subset(mask_shape, self.plates):
            template = (
                "The mask of the node %s has updated "
                "incorrectly. The plates in the mask %s are not a "
                "subset of the plates of the node %s."
            )
            raise ValueError(template % (self.name, mask_shape, self.plates))

        # Propagate upwards
        for parent in self.parents:
            parent._update_mask()
Example #13
0
 def _total_plates(cls, plates, *parent_plates):
     if plates is None:
         # By default, use the minimum number of plates determined
         # from the parent nodes
         try:
             return misc.broadcasted_shape(*parent_plates)
         except ValueError:
             raise ValueError("The plates of the parents do not broadcast.")
     else:
         # Check that the parent_plates are a subset of plates.
         for (ind, p) in enumerate(parent_plates):
             if not misc.is_shape_subset(p, plates):
                 raise ValueError("The plates %s of the parents "
                                  "are not broadcastable to the given "
                                  "plates %s."
                                  % (p,
                                     plates))
         return plates
Example #14
0
    def _update_mask(self):
        """
        Refresh the mask of this node and of all its ancestors.

        The new mask is the logical "or" of the masks obtained from each
        ``child._mask_to_parent(index)`` in ``self.children``, starting from a
        scalar False.  It is stored through ``_set_mask`` and afterwards
        ``_update_mask`` is called on every parent.

        Raises
        ------
        ValueError
            If the shape of ``self.mask`` after the update is not a sub-shape
            of ``self.plates``.
        """
        # Combine masks from children
        merged = np.array(False)
        for (child, index) in self.children:
            merged = np.logical_or(merged, child._mask_to_parent(index))

        # Set the mask of this node
        self._set_mask(merged)

        shape_ok = misc.is_shape_subset(np.shape(self.mask), self.plates)
        if not shape_ok:
            raise ValueError(
                "The mask of the node %s has updated "
                "incorrectly. The plates in the mask %s are not a "
                "subset of the plates of the node %s."
                % (self.name, np.shape(self.mask), self.plates)
            )

        # Tell parents to update their masks
        for parent in self.parents:
            parent._update_mask()
Example #15
0
    def _plate_multiplier(plates, *args):
        """
        Compute the plate multiplier for given shapes.

        The first shape is compared to all other shapes (using NumPy
        broadcasting rules). All the elements which are non-unit in the first
        shape but 1 in all other shapes are multiplied together.

        This method is used, for instance, for computing a correction factor for
        messages to parents: If this node has non-unit plates that are unit
        plates in the parent, those plates are summed. However, if the message
        has unit axis for that plate, it should be first broadcasted to the
        plates of this node and then summed to the plates of the parent. In
        order to avoid this broadcasting and summing, it is more efficient to
        just multiply by the correct factor. This method computes that
        factor. The first argument is the full plate shape of this node (with
        respect to the parent). The other arguments are the shape of the message
        array and the plates of the parent (with respect to this node).
        """
        
        # Check broadcasting of the shapes
        for arg in args:
            misc.broadcasted_shape(plates, arg)

        # Check that each arg-plates are a subset of plates?
        for arg in args:
            if not misc.is_shape_subset(arg, plates):
                raise ValueError("The shapes in args are not a sub-shape of "
                                 "plates.")
            
        r = 1
        for j in range(-len(plates),0):
            mult = True
            for arg in args:
                # if -j <= len(arg) and arg[j] != 1:
                if not (-j > len(arg) or arg[j] == 1):
                    mult = False
            if mult:
                r *= plates[j]
        return r
Example #16
0
    def _compute_message(*arrays, plates_from=(), plates_to=(), ndim=0):
        """
        A general function for computing messages by sum-multiply

        The function computes the product of the input arrays and then sums to
        the requested plates.

        Parameters
        ----------
        *arrays : array-like
            Arrays to multiply together.  Their shapes must broadcast.  The
            last `ndim` axes of the broadcasted shape are treated as variable
            axes and the remaining leading axes as plate axes.
        plates_from : tuple, optional
            Full plate shape; passed to ``Node.broadcasting_multiplier`` to
            compute the correction for plates that are only broadcasted in
            `arrays`.
        plates_to : tuple, optional
            Plate shape to sum to.  Must be a sub-shape of `plates_from`.
            Plate axes that are 1 or missing in `plates_to` are summed over.
        ndim : int, optional
            Number of trailing variable axes, which are never summed.

        Returns
        -------
        ndarray
            The summed product, with at most ``len(plates_to)`` plate axes
            followed by the variable axes.

        Raises
        ------
        ValueError
            If `plates_to` is not a sub-shape of `plates_from`.
        """

        # Check that the plates broadcast properly
        if not misc.is_shape_subset(plates_to, plates_from):
            raise ValueError("plates_to must be broadcastable to plates_from")

        # Compute the explicit shape of the product
        shapes = [np.shape(array) for array in arrays]
        arrays_shape = misc.broadcasted_shape(*shapes)

        # Compute plates and dims that are present
        if ndim == 0:
            arrays_plates = arrays_shape
            dims = ()
        else:
            arrays_plates = arrays_shape[:-ndim]
            dims = arrays_shape[-ndim:]

        # Compute the correction term.  If some of the plates that should be
        # summed are actually broadcasted, one must multiply by the size of the
        # corresponding plate
        r = Node.broadcasting_multiplier(plates_from, arrays_plates, plates_to)

        # For simplicity, make the arrays equal ndim
        arrays = misc.make_equal_ndim(*arrays)

        # Keys for the input plates: (N-1, N-2, ..., 0)
        #
        # The keys count the plate axes from the end (the last plate axis has
        # key 0), so that key k refers to the same axis as plates_to[-k-1]
        # below even when plates_to has fewer axes than the arrays.
        nplates = len(arrays_plates)
        in_plate_keys = list(range(nplates - 1, -1, -1))

        # Keys for the output plates: keep only the axes that exist in
        # plates_to and are not unit length there
        out_plate_keys = [key
                          for key in in_plate_keys
                          if key < len(plates_to) and plates_to[-key-1] != 1]

        # Keys for the dims
        dim_keys = list(range(nplates, nplates+ndim))

        # Total input and output keys
        in_keys = len(arrays) * [in_plate_keys + dim_keys]
        out_keys = out_plate_keys + dim_keys

        # Compute the sum-product with correction
        einsum_args = misc.zipper_merge(arrays, in_keys) + [out_keys]
        y = r * np.einsum(*einsum_args)

        # Reshape the result so that the summed plate axes which exist in
        # plates_to are restored as unit axes
        nplates_result = min(len(plates_to), len(arrays_plates))
        if nplates_result == 0:
            plates_result = []
        else:
            plates_result = [min(plates_to[ind], arrays_plates[ind])
                             for ind in range(-nplates_result, 0)]
        y = np.reshape(y, plates_result + list(dims))

        return y
Example #17
0
    def _compute_message(*arrays, plates_from=(), plates_to=(), ndim=0):
        """
        A general function for computing messages by sum-multiply

        The function computes the product of the input arrays and then sums to
        the requested plates.

        Parameters
        ----------
        *arrays : array-like
            Arrays to multiply together.  Their shapes must broadcast.  The
            last `ndim` axes of the broadcasted shape are treated as variable
            axes and the remaining leading axes as plate axes.
        plates_from : tuple, optional
            Full plate shape; passed to ``Node.broadcasting_multiplier`` to
            compute the correction for plates that are only broadcasted in
            `arrays`.
        plates_to : tuple, optional
            Plate shape to sum to.  Must be a sub-shape of `plates_from`.
            Plate axes that are 1 or missing in `plates_to` are summed over.
        ndim : int, optional
            Number of trailing variable axes, which are never summed.

        Returns
        -------
        ndarray
            The summed product, with at most ``len(plates_to)`` plate axes
            followed by the variable axes.

        Raises
        ------
        ValueError
            If `plates_to` is not a sub-shape of `plates_from`.
        """

        # Check that the plates broadcast properly
        if not misc.is_shape_subset(plates_to, plates_from):
            raise ValueError("plates_to must be broadcastable to plates_from")

        # Compute the explicit shape of the product
        shapes = [np.shape(array) for array in arrays]
        arrays_shape = misc.broadcasted_shape(*shapes)

        # Compute plates and dims that are present
        if ndim == 0:
            arrays_plates = arrays_shape
            dims = ()
        else:
            arrays_plates = arrays_shape[:-ndim]
            dims = arrays_shape[-ndim:]

        # Compute the correction term.  If some of the plates that should be
        # summed are actually broadcasted, one must multiply by the size of the
        # corresponding plate
        r = Node.broadcasting_multiplier(plates_from, arrays_plates, plates_to)

        # For simplicity, make the arrays equal ndim
        arrays = misc.make_equal_ndim(*arrays)

        # Keys for the input plates: (N-1, N-2, ..., 0)
        #
        # Plate axes are numbered from the end (the trailing plate axis gets
        # key 0).  The filter below looks up plates_to[-key - 1], so the keys
        # must be in this reversed order for the lookup to hit the same axis.
        nplates = len(arrays_plates)
        in_plate_keys = list(range(nplates - 1, -1, -1))

        # Keys for the output plates: keep only the axes that exist in
        # plates_to and are not unit length there
        out_plate_keys = [
            key for key in in_plate_keys
            if key < len(plates_to) and plates_to[-key - 1] != 1
        ]

        # Keys for the dims
        dim_keys = list(range(nplates, nplates + ndim))

        # Total input and output keys
        in_keys = len(arrays) * [in_plate_keys + dim_keys]
        out_keys = out_plate_keys + dim_keys

        # Compute the sum-product with correction
        einsum_args = misc.zipper_merge(arrays, in_keys) + [out_keys]
        y = r * np.einsum(*einsum_args)

        # Reshape the result so that the summed plate axes which exist in
        # plates_to are restored as unit axes
        nplates_result = min(len(plates_to), len(arrays_plates))
        if nplates_result == 0:
            plates_result = []
        else:
            plates_result = [
                min(plates_to[ind], arrays_plates[ind])
                for ind in range(-nplates_result, 0)
            ]
        y = np.reshape(y, plates_result + list(dims))

        return y