def _mask_to_parent(self, index):
    """
    Return this node's mask expressed in the plates of parent[index].

    The mask marks which plate connections are active.  It is first
    computed by ``_compute_mask_to_parent``, validated against the plates
    of this node with respect to the parent, and then reduced with a
    logical "or" over the plates that have unit length in the parent
    node.  Because plates have been reduced, the result is intended for
    propagating the mask to parents and not for masking messages.

    Raises ``ValueError`` when ``misc.is_shape_subset`` reports that the
    mask shape is not a sub-shape of the plates towards the parent.
    """
    mask = self._compute_mask_to_parent(index, self.mask)
    mask_shape = np.shape(mask)

    # Validate the mask shape before reducing anything.
    expected_plates = self._plates_to_parent(index)
    if not misc.is_shape_subset(mask_shape, expected_plates):
        template = (
            "In node %s, the mask being sent to "
            "parent[%d] (%s) has invalid shape: The shape of "
            "the mask %s is not a sub-shape of the plates of "
            "the node with respect to the parent %s. It could "
            "be that this node (%s) is manipulating plates "
            "but has not overwritten the method "
            "_compute_mask_to_parent."
        )
        raise ValueError(
            template % (
                self.name,
                index,
                self.parents[index].name,
                mask_shape,
                expected_plates,
                self.__class__.__name__,
            )
        )

    # Logical "or" over the axes that misc.axes_to_collapse selects for
    # going from the mask shape to the parent's plates.
    target_plates = self.parents[index].plates
    reduce_axes = misc.axes_to_collapse(mask_shape, target_plates)
    reduced = np.any(mask, axis=reduce_axes, keepdims=True)

    # Trim the result to the dimensionality of the parent's plates.
    return misc.squeeze_to_dim(reduced, len(target_plates))
def message_sum_multiply(plates_parent, dims_parent, *arrays):
    """
    Build the message to a parent from ``arrays`` and sum over plates.

    The arrays are combined with ``misc.sum_multiply`` over the axes that
    ``misc.axes_to_collapse`` reports between the broadcasted shape of the
    arrays and ``plates_parent + dims_parent``.  The result is divided by
    the product of the lengths of the summed plate axes, which cancels
    the plate multiplier that is applied later in message_to_parent.
    Finally the result is passed through ``misc.squeeze_to_dim`` with the
    length of ``plates_parent + dims_parent``.
    """
    # Shape obtained when all the arrays are broadcast together.
    full_shape = misc.broadcasted_shape(*[np.shape(a) for a in arrays])

    # Axes of the full message that have to be summed away.
    target_shape = plates_parent + dims_parent
    collapsed_axes = misc.axes_to_collapse(full_shape, target_shape)

    # Because the summing is done here already (for efficiency), the
    # effect of the plate multiplier must be cancelled.  Only plate axes
    # count: a non-negative axis below the number of parent plates, or an
    # axis index lying before the trailing variable dimensions.
    n_plates = len(plates_parent)
    n_dims = len(dims_parent)
    divisor = 1
    for axis in collapsed_axes:
        is_leading_plate = 0 <= axis < n_plates
        is_trailing_plate = axis < 0 and axis < -n_dims
        if is_leading_plate or is_trailing_plate:
            divisor *= full_shape[axis]

    # Sum-product over the collapsed axes, scaled by the divisor.
    summed = misc.sum_multiply(
        *arrays,
        axis=collapsed_axes,
        sumaxis=True,
        keepdims=True
    )
    message = summed / divisor

    # Drop the extra leading axes.
    return misc.squeeze_to_dim(message, len(target_shape))
def _mask_to_parent(self, index):
    """
    Compute the mask that is propagated to parent[index].

    The mask describes which plate connections are active.  After a
    shape check it is "summed" (logical or) over the plates that have
    unit length in the parent node and trimmed to the number of plate
    axes of the parent.  Since some plates are already reduced, the
    returned mask must not be used for masking messages; it only serves
    for passing the mask on to parents.

    A ``ValueError`` is raised if the mask computed by
    ``_compute_mask_to_parent`` does not pass ``misc.is_shape_subset``
    against the plates of this node with respect to the parent.
    """
    mask = self._compute_mask_to_parent(index, self.mask)

    # The mask must fit into the plates this node shows to the parent.
    plates_to_parent = self._plates_to_parent(index)
    shape_is_valid = misc.is_shape_subset(np.shape(mask), plates_to_parent)
    if not shape_is_valid:
        details = (
            self.name,
            index,
            self.parents[index].name,
            np.shape(mask),
            plates_to_parent,
            self.__class__.__name__,
        )
        raise ValueError(
            "In node %s, the mask being sent to "
            "parent[%d] (%s) has invalid shape: The shape of "
            "the mask %s is not a sub-shape of the plates of "
            "the node with respect to the parent %s. It could "
            "be that this node (%s) is manipulating plates "
            "but has not overwritten the method "
            "_compute_mask_to_parent."
            % details
        )

    # Or-reduce over the axes picked by misc.axes_to_collapse, keeping
    # them as singleton axes, then squeeze to the parent's plate count.
    parent_plates = self.parents[index].plates
    axes = misc.axes_to_collapse(np.shape(mask), parent_plates)
    collapsed = np.any(mask, axis=axes, keepdims=True)
    collapsed = misc.squeeze_to_dim(collapsed, len(parent_plates))
    return collapsed
def m_function(*args): lpdf = m(*args) # Log pdf only contains plate axes! plates_m = np.shape(lpdf) r = (self.broadcasting_multiplier(plates_self, plates_m, plates_mask, parent.plates) * self.broadcasting_multiplier(self.plates_multiplier, multiplier_parent)) axes_msg = misc.axes_to_collapse(plates_m, parent.plates) m[i] = misc.sum_multiply(mask_i, m[i], r, axis=axes_msg, keepdims=True) # Remove leading singular plates if the parent does not have # those plate axes. m[i] = misc.squeeze_to_dim(m[i], len(shape_parent))
def _message_to_parent(self, index):
    """
    Compute the message to parent[index] in the shape the parent expects.

    The raw message and mask come from
    ``_get_message_and_mask_to_parent``.  For every non-``None`` element
    of the message the plates are checked, the mask is applied and the
    plates that the parent does not have are summed over.  The elements
    of the message are replaced in place and the message is returned.

    Raises ``ValueError`` if ``index`` is not smaller than the number of
    parents, or if ``_plate_multiplier`` rejects the combination of
    plates of the message, the mask, this node and the parent.
    """
    if index >= len(self.parents):
        raise ValueError("Parent index larger than the number of parents")

    # Raw message and mask; the mask is squeezed before its plates are
    # read.
    (m, mask) = self._get_message_and_mask_to_parent(index)
    mask = misc.squeeze(mask)
    plates_mask = np.shape(mask)

    # The receiver of the message.
    parent = self.parents[index]

    # Bring each element of the message into its proper shape.
    for (i, msg) in enumerate(m):
        # Empty messages are given as None and need no processing.
        if msg is None:
            continue

        # Split the message shape into plates and variable dimensions.
        shape_m = np.shape(msg)
        ndim_parent = len(parent.dims[i])
        plates_m = shape_m[:-ndim_parent] if ndim_parent > 0 else shape_m

        # Multiplier for the plates on which the message, the mask and
        # the parent all have unit length.  Such a plate is meant to be
        # broadcasted, but as the parent has a singular plate axis there
        # it will not be broadcasted (and summed over), so the message
        # has to be multiplied instead.
        plates_self = self._plates_to_parent(index)
        try:
            r = self._plate_multiplier(plates_self,
                                       plates_m,
                                       plates_mask,
                                       parent.plates)
        except ValueError:
            raise ValueError(
                "The plates of the message, the mask and "
                "parent[%d] node (%s) are not a "
                "broadcastable subset of the plates of "
                "this node (%s). The message has shape "
                "%s, meaning plates %s. The mask has "
                "plates %s. This node has plates %s with "
                "respect to the parent[%d], which has "
                "plates %s."
                % (index,
                   parent.name,
                   self.name,
                   shape_m,
                   plates_m,
                   plates_mask,
                   plates_self,
                   index,
                   parent.plates)
            )

        # Append unit axes for the variable dimensions to the mask.
        shape_mask = plates_mask + (1,) * ndim_parent
        mask_i = np.reshape(mask, shape_mask)

        # Sum the mask over plates that are neither in the message nor
        # in the parent.
        shape_parent = parent.get_shape(i)
        shape_msg = misc.broadcasted_shape(shape_m, shape_parent)
        axes_mask = misc.axes_to_collapse(shape_mask, shape_msg)
        mask_i = np.sum(mask_i, axis=axes_mask, keepdims=True)

        # Masked message, summed over the plates the parent lacks.
        axes_msg = misc.axes_to_collapse(shape_msg, shape_parent)
        m[i] = misc.sum_multiply(mask_i,
                                 msg,
                                 r,
                                 axis=axes_msg,
                                 keepdims=True)

        # Drop leading singular plates that the parent does not have.
        m[i] = misc.squeeze_to_dim(m[i], len(shape_parent))

    return m