def _mask_to_parent(self, index):
    """
    Return this node's mask as seen from ``self.parents[index]``.

    The node's own mask (``self.mask``) is first mapped towards the parent
    with ``self._compute_mask_to_parent``. That mapped mask must be a
    sub-shape of ``self._plates_to_parent(index)``. Then the axes selected by
    ``utils.axes_to_collapse`` (the plates that have unit length in the
    parent) are reduced with a logical OR (``np.any``). The result is passed
    through ``utils.squeeze_to_dim`` with the parent's number of plate axes.

    Because some plates have already been OR-reduced, the returned mask is
    only suitable for propagating masks to parents, not for masking
    messages.

    Raises
    ------
    ValueError
        If the mapped mask is not a sub-shape of the plates of this node
        with respect to the parent (typically a subclass that manipulates
        plates without overriding ``_compute_mask_to_parent``).
    """
    mapped = self._compute_mask_to_parent(index, self.mask)
    expected_plates = self._plates_to_parent(index)
    mapped_shape = np.shape(mapped)

    shape_ok = utils.is_shape_subset(mapped_shape, expected_plates)
    if not shape_ok:
        details = (self.name,
                   index,
                   self.parents[index].name,
                   mapped_shape,
                   expected_plates,
                   self.__class__.__name__)
        raise ValueError("In node %s, the mask being sent to "
                         "parent[%d] (%s) has invalid shape: The shape of "
                         "the mask %s is not a sub-shape of the plates of "
                         "the node with respect to the parent %s. It could "
                         "be that this node (%s) is manipulating plates "
                         "but has not overwritten the method "
                         "_compute_mask_to_parent." % details)

    # Logical-OR over the plates that are unit-length in the parent.
    target_plates = self.parents[index].plates
    collapse_axes = utils.axes_to_collapse(mapped_shape, target_plates)
    reduced = np.any(mapped, axis=collapse_axes, keepdims=True)
    return utils.squeeze_to_dim(reduced, len(target_plates))
def _update_phi_from_parents(self, *u_parents):
    """
    Recompute the natural parameters ``self.phi`` from parent moments.

    The raw parameters come from
    ``self._distribution.compute_phi_from_parents(*u_parents)`` and are
    stored as a list. Each ``phi[i]`` is then normalized to exactly
    ``len(self.plates) + self.ndims[i]`` axes:

    * missing leading axes are added with ``utils.add_leading_axes``;
    * surplus leading axes are dropped by reshaping to the trailing shape
      (``np.reshape`` raises if those surplus axes are not of length 1).

    Finally each ``phi[i]`` is checked to be broadcastable to
    ``self.get_shape(i)``.

    Raises
    ------
    ValueError
        If a normalized ``phi[i]`` is not a sub-shape of ``self.get_shape(i)``.
    """
    # Not merged into _update_distribution_and_lowerbound because some
    # initialization methods call this method directly.
    self.phi = list(self._distribution.compute_phi_from_parents(*u_parents))

    # Give every phi[i] the full number of axes; it simplifies code elsewhere.
    for i, phi_i in enumerate(self.phi):
        target_ndim = len(self.plates) + self.ndims[i]
        extra = np.ndim(phi_i) - target_ndim
        if extra < 0:
            # Add the missing leading axes
            phi_i = utils.add_leading_axes(phi_i, -extra)
        elif extra > 0:
            # Drop the surplus leading axes. Slicing by the number of extra
            # axes also works when target_ndim == 0, where the previous
            # negative index -(target_ndim) == 0 kept every axis.
            phi_i = np.reshape(phi_i, np.shape(phi_i)[extra:])
        self.phi[i] = phi_i

        # Check that the shape is correct
        if not utils.is_shape_subset(np.shape(phi_i), self.get_shape(i)):
            raise ValueError("Incorrect shape of phi[%d] in node class %s. "
                             "Shape is %s but it should be broadcastable "
                             "to shape %s."
                             % (i,
                                self.__class__.__name__,
                                np.shape(phi_i),
                                self.get_shape(i)))
def _update_phi_from_parents(self, *u_parents):
    """
    Recompute the natural parameters ``self.phi`` from parent moments.

    The raw parameters come from
    ``self._distribution.compute_phi_from_parents(*u_parents)`` and are
    stored as a list. Each ``phi[i]`` is then normalized to exactly
    ``len(self.plates) + self._distribution.ndims[i]`` axes:

    * missing leading axes are added with ``utils.add_leading_axes``;
    * surplus leading axes are dropped by reshaping to the trailing shape
      (``np.reshape`` raises if those surplus axes are not of length 1).

    Finally each ``phi[i]`` is checked to be broadcastable to
    ``self.get_shape(i)``.

    Raises
    ------
    ValueError
        If a normalized ``phi[i]`` is not a sub-shape of ``self.get_shape(i)``.
    """
    # Not merged into _update_distribution_and_lowerbound because some
    # initialization methods call this method directly.
    self.phi = list(self._distribution.compute_phi_from_parents(*u_parents))

    # Give every phi[i] the full number of axes; it simplifies code elsewhere.
    for i, phi_i in enumerate(self.phi):
        target_ndim = len(self.plates) + self._distribution.ndims[i]
        extra = np.ndim(phi_i) - target_ndim
        if extra < 0:
            # Add the missing leading axes
            phi_i = utils.add_leading_axes(phi_i, -extra)
        elif extra > 0:
            # Drop the surplus leading axes. Slicing by the number of extra
            # axes also works when target_ndim == 0, where the previous
            # negative index -(target_ndim) == 0 kept every axis.
            phi_i = np.reshape(phi_i, np.shape(phi_i)[extra:])
        self.phi[i] = phi_i

        # Check that the shape is correct
        if not utils.is_shape_subset(np.shape(phi_i), self.get_shape(i)):
            raise ValueError("Incorrect shape in phi[%d]. Shape is %s but "
                             "it should be broadcastable to shape %s."
                             % (i,
                                np.shape(phi_i),
                                self.get_shape(i)))
def _message_to_child(self):
    """
    Return the moments of this node as the message sent to its children.

    As a debugging aid, each moment array is validated against the node
    before returning:

    * its trailing axes must equal the corresponding entry of ``self.dims``;
    * the remaining leading axes must be a sub-shape of ``self.plates``.

    Raises
    ------
    RuntimeError
        If either check fails; this indicates an internal bug, e.g. plates
        inferred incorrectly from the parents.
    """
    moments = self.get_moments()

    for moment, var_dims in zip(moments, self.dims):
        shape = np.shape(moment)
        n_var = len(var_dims)

        if n_var > 0:
            trailing = shape[-n_var:]
            if trailing != var_dims:
                raise RuntimeError(
                    "A bug found by _message_to_child for %s: "
                    "The variable axes of the moments %s are not equal to "
                    "the axes %s defined by the node %s. A possible reason "
                    "is that the plates of the node are inferred "
                    "incorrectly from the parents, and the method "
                    "_plates_from_parents should be implemented."
                    % (self.__class__.__name__,
                       trailing,
                       var_dims,
                       self.name))

        # Everything in front of the variable axes is plate axes. Written
        # with an explicit length so that n_var == 0 keeps the full shape.
        leading = shape[:len(shape) - n_var]
        if not utils.is_shape_subset(leading, self.plates):
            raise RuntimeError(
                "A bug found by _message_to_child for %s: "
                "The plate axes of the moments %s are not a subset of "
                "the plate axes %s defined by the node %s."
                % (self.__class__.__name__,
                   leading,
                   self.plates,
                   self.name))

    return moments
def __init__(self, *parents, dims=None, plates=None, name=""):
    """
    Initialize the node and register it with its parents.

    Parameters
    ----------
    *parents
        Parent nodes, stored as ``self.parents``. Falsy entries are kept but
        are not informed of this child.
    dims
        Dimensionality of the distribution (required); stored as-is in
        ``self.dims``.
    plates : optional
        Plate shape of this node. If None, it is the broadcast of the plates
        returned by ``self._plates_from_parent`` for every parent.
    name : str
        Name of the node.

    Raises
    ------
    Exception
        If ``dims`` is not given.
    ValueError
        If ``_plates_from_parent`` returns None, if the parent plates do not
        broadcast, or if they are not sub-shapes of the given ``plates``.

    The plates are resolved and validated before any parent is informed via
    ``_add_child``. A failed construction therefore never leaves a
    half-initialized node in a parent's child list.
    """
    if dims is None:
        raise Exception("You need to specify the dimensionality of the "
                        "distribution for class %s" % str(self.__class__))
    self.dims = dims
    self.name = name

    # Parents
    self.parents = parents

    # Check plates
    parent_plates = [self._plates_from_parent(index)
                     for index in range(len(self.parents))]
    if any(p is None for p in parent_plates):
        raise ValueError("Method _plates_from_parent returned None")

    if plates is None:
        # By default, use the minimum number of plates determined
        # from the parent nodes
        try:
            self.plates = utils.broadcasted_shape(*parent_plates)
        except ValueError as err:
            raise ValueError("The plates of the parents do not broadcast.") from err
    else:
        # Use custom plates
        self.plates = plates
        # Check that the parent_plates are a subset of plates.
        for p in parent_plates:
            if not utils.is_shape_subset(p, plates):
                raise ValueError("The plates of the parents are not "
                                 "subsets of the given plates.")

    # By default, ignore all plates
    self.mask = np.array(False)

    # Children
    self.children = list()

    # Inform parent nodes only now that this node is fully initialized
    for index, parent in enumerate(self.parents):
        if parent:
            parent._add_child(self, index)
def _total_plates(cls, plates, *parent_plates):
    """
    Resolve the plates of a node from explicit plates and parent plates.

    Parameters
    ----------
    plates
        Explicit plate shape, or None to derive it from ``parent_plates``.
    *parent_plates
        Plate shapes of the parents.

    Returns
    -------
    If ``plates`` is None, the broadcast of all ``parent_plates`` (via
    ``utils.broadcasted_shape``). Otherwise ``plates`` itself, unchanged.

    Raises
    ------
    ValueError
        If the parent plates do not broadcast, or if some parent plates are
        not a sub-shape of the given ``plates``.
    """
    if plates is None:
        # By default, use the minimum number of plates determined
        # from the parent nodes
        try:
            return utils.broadcasted_shape(*parent_plates)
        except ValueError as err:
            raise ValueError("The plates of the parents do not broadcast.") from err

    # Check that the parent_plates are a subset of plates.
    for p in parent_plates:
        if not utils.is_shape_subset(p, plates):
            raise ValueError("The plates %s of the parents "
                             "are not broadcastable to the given "
                             "plates %s." % (p, plates))
    return plates
def _update_mask(self):
    """
    Recompute this node's mask from its children and propagate it upward.

    The new mask is the logical OR, starting from ``np.array(False)``, of
    ``child._mask_to_parent(index)`` over every ``(child, index)`` pair in
    ``self.children``. It is stored with ``self._set_mask`` and validated
    against ``self.plates``. Every non-empty parent is then asked to update
    its own mask.

    Raises
    ------
    ValueError
        If the stored mask's shape is not a sub-shape of ``self.plates``.
    """
    # Combine masks from children
    mask = np.array(False)
    for child, index in self.children:
        mask = np.logical_or(mask, child._mask_to_parent(index))

    # Set the mask of this node
    self._set_mask(mask)
    if not utils.is_shape_subset(np.shape(self.mask), self.plates):
        raise ValueError(
            "The mask of the node %s has updated "
            "incorrectly. The plates in the mask %s are not a "
            "subset of the plates of the node %s."
            % (self.name, np.shape(self.mask), self.plates))

    # Tell parents to update their masks. Node.__init__ only registers with
    # truthy parents (``if parent:``), so empty parent slots are skipped.
    for parent in self.parents:
        if parent:
            parent._update_mask()
def _update_mask(self):
    """
    Recompute this node's mask from its children and propagate it upward.

    The new mask is the logical OR, starting from ``np.array(False)``, of
    ``child._mask_to_parent(index)`` over every ``(child, index)`` pair in
    ``self.children``. It is stored with ``self._set_mask`` and validated
    against ``self.plates``. Every non-empty parent is then asked to update
    its own mask.

    Raises
    ------
    ValueError
        If the stored mask's shape is not a sub-shape of ``self.plates``.
    """
    # Combine masks from children
    mask = np.array(False)
    for child, index in self.children:
        mask = np.logical_or(mask, child._mask_to_parent(index))

    # Set the mask of this node
    self._set_mask(mask)
    if not utils.is_shape_subset(np.shape(self.mask), self.plates):
        raise ValueError("The mask of the node %s has updated "
                         "incorrectly. The plates in the mask %s are not a "
                         "subset of the plates of the node %s."
                         % (self.name, np.shape(self.mask), self.plates))

    # Tell parents to update their masks. Node.__init__ only registers with
    # truthy parents (``if parent:``), so empty parent slots are skipped.
    for parent in self.parents:
        if parent:
            parent._update_mask()
def __init__(self, *parents, dims=None, plates=None, name=""):
    """
    Initialize the node and register it with its parents.

    Parameters
    ----------
    *parents
        Parent nodes, stored as ``self.parents``. Falsy entries are kept but
        are not informed of this child.
    dims
        Dimensionality of the distribution (required); stored as-is in
        ``self.dims``.
    plates : optional
        Plate shape of this node. If None, it is the broadcast of the plates
        returned by ``self._plates_from_parent`` for every parent.
    name : str
        Name of the node.

    Raises
    ------
    Exception
        If ``dims`` is not given.
    ValueError
        If ``_plates_from_parent`` returns None, if the parent plates do not
        broadcast, or if they are not sub-shapes of the given ``plates``.

    The plates are resolved and validated before any parent is informed via
    ``_add_child``. A failed construction therefore never leaves a
    half-initialized node in a parent's child list.
    """
    if dims is None:
        raise Exception("You need to specify the dimensionality of the "
                        "distribution for class %s" % str(self.__class__))
    self.dims = dims
    self.name = name

    # Parents
    self.parents = parents

    # Check plates
    parent_plates = [
        self._plates_from_parent(index)
        for index in range(len(self.parents))
    ]
    # Fail early with a clear message instead of an obscure error from the
    # shape helpers below.
    if any(p is None for p in parent_plates):
        raise ValueError("Method _plates_from_parent returned None")

    if plates is None:
        # By default, use the minimum number of plates determined
        # from the parent nodes
        try:
            self.plates = utils.broadcasted_shape(*parent_plates)
        except ValueError as err:
            raise ValueError("The plates of the parents do not broadcast.") from err
    else:
        # Use custom plates
        self.plates = plates
        # Check that the parent_plates are a subset of plates.
        for p in parent_plates:
            if not utils.is_shape_subset(p, plates):
                raise ValueError("The plates of the parents are not "
                                 "subsets of the given plates.")

    # By default, ignore all plates
    self.mask = np.array(False)

    # Children
    self.children = list()

    # Inform parent nodes only now that this node is fully initialized
    for index, parent in enumerate(self.parents):
        if parent:
            parent._add_child(self, index)
def _plate_multiplier(plates, *args):
    """
    Return the product of the plate sizes that every shape in ``args`` collapses.

    Axes are aligned from the right. An axis of ``plates`` contributes its
    size when each shape in ``args`` either lacks that axis or has length 1
    on it. With no ``args`` the result is the product of all of ``plates``;
    with empty ``plates`` it is 1.

    Each shape in ``args`` is first passed to ``utils.broadcasted_shape``
    together with ``plates``, and must then be a sub-shape of ``plates``;
    otherwise ValueError is raised.
    """
    # Every shape must broadcast against the full plates...
    for shape in args:
        utils.broadcasted_shape(plates, shape)
    # ...and must moreover be a sub-shape of them.
    for shape in args:
        if not utils.is_shape_subset(shape, plates):
            raise ValueError("The shapes in args are not a sub-shape of "
                             "plates.")

    factor = 1
    for axis in range(-len(plates), 0):
        collapsed = all(-axis > len(shape) or shape[axis] == 1
                        for shape in args)
        if collapsed:
            factor *= plates[axis]
    return factor
def _plate_multiplier(plates, *args):
    """
    Multiply together the plate sizes that are unit (or absent) in all args.

    ``plates`` is the full shape; each element of ``args`` is a shape that
    must broadcast to it (checked with ``utils.broadcasted_shape``) and be a
    sub-shape of it (otherwise ValueError). Aligning from the right, every
    axis of ``plates`` on which no arg has a non-unit size contributes
    ``plates[axis]`` to the product. The empty product is 1.
    """
    for arg_shape in args:
        utils.broadcasted_shape(plates, arg_shape)

    if any(not utils.is_shape_subset(arg_shape, plates) for arg_shape in args):
        raise ValueError("The shapes in args are not a sub-shape of "
                         "plates.")

    selected_sizes = [
        plates[axis]
        for axis in range(-len(plates), 0)
        if all(-axis > len(arg_shape) or arg_shape[axis] == 1
               for arg_shape in args)
    ]

    product = 1
    for size in selected_sizes:
        product *= size
    return product
def _plate_multiplier(plates, *args):
    """
    Compute the correction factor for plates that are summed implicitly.

    ``plates`` is the full plate shape of this node (with respect to the
    parent). ``args`` are the other shapes involved, typically the shape of
    the message array and the plates of the parent (with respect to this
    node). Shapes are compared from the right, following NumPy broadcasting.
    The factor is the product of those entries of ``plates`` for which every
    shape in ``args`` has length 1 or lacks the axis altogether.

    Motivation: when this node has non-unit plates that are unit plates in
    the parent, the message must be summed over them. If the message itself
    has a unit axis there, broadcasting it to the full plates and summing
    would just multiply it by the plate size. Multiplying by this factor
    directly gives the same result without the intermediate broadcast.

    Each shape in ``args`` is passed to ``utils.broadcasted_shape`` together
    with ``plates`` and must be a sub-shape of ``plates``; otherwise
    ValueError is raised.
    """
    for other in args:
        utils.broadcasted_shape(plates, other)
    for other in args:
        if not utils.is_shape_subset(other, plates):
            raise ValueError("The shapes in args are not a sub-shape of "
                             "plates.")

    total = 1
    # offset runs len(plates) .. 1, i.e. the same left-to-right axis order
    # as indexing plates[-offset].
    for offset in range(len(plates), 0, -1):
        unit_everywhere = True
        for other in args:
            if offset <= len(other) and other[-offset] != 1:
                unit_everywhere = False
        if unit_everywhere:
            total *= plates[-offset]
    return total