def _mask_to_parent(self, index):
    """
    Return this node's mask projected onto the plates of parent[index].

    The mask marks which plate connections are active. It is reduced with
    a logical OR over every plate axis on which the parent has unit length,
    and then squeezed to the parent's number of plate axes. Because some
    plates have been merged, the result cannot be used for masking
    messages; it is intended for propagating the mask up to the parents.

    Parameters
    ----------
    index : int
        Index of the parent in ``self.parents``.

    Returns
    -------
    ndarray
        Boolean mask in the plate shape of the parent.

    Raises
    ------
    ValueError
        If the mask returned by ``_compute_mask_to_parent`` is not a
        sub-shape of this node's plates with respect to the parent.
    """
    mask = self._compute_mask_to_parent(index, self.mask)
    mask_shape = np.shape(mask)

    # The mask must fit inside the plates this node has w.r.t. the parent.
    expected_plates = self._plates_to_parent(index)
    if not misc.is_shape_subset(mask_shape, expected_plates):
        raise ValueError(
            "In node %s, the mask being sent to "
            "parent[%d] (%s) has invalid shape: The shape of "
            "the mask %s is not a sub-shape of the plates of "
            "the node with respect to the parent %s. It could "
            "be that this node (%s) is manipulating plates "
            "but has not overwritten the method "
            "_compute_mask_to_parent."
            % (self.name,
               index,
               self.parents[index].name,
               mask_shape,
               expected_plates,
               self.__class__.__name__))

    # Logical OR over the plates that are singleton in the parent, then drop
    # surplus leading axes so the result has the parent's plate rank.
    parent_plates = self.parents[index].plates
    collapse_axes = misc.axes_to_collapse(mask_shape, parent_plates)
    reduced = np.any(mask, axis=collapse_axes, keepdims=True)
    return misc.squeeze_to_dim(reduced, len(parent_plates))
def _mask_to_parent(self, index):
    """
    Mask of this node reduced to the plate shape of ``self.parents[index]``.

    The mask tells which plate connections are active. Plates on which the
    parent is singular are combined with a logical OR, and the result is
    squeezed to the parent's number of plate axes. The output is only
    suitable for propagating the mask to the parent, not for masking
    messages, since some plates have already been merged.

    Parameters
    ----------
    index : int
        Index of the parent in ``self.parents``.

    Returns
    -------
    ndarray
        Boolean mask shaped like the parent's plates.

    Raises
    ------
    ValueError
        If the computed mask is not a sub-shape of the plates of this node
        with respect to the parent.
    """
    mask = self._compute_mask_to_parent(index, self.mask)
    plates_to_parent = self._plates_to_parent(index)

    # A node that manipulates plates must supply a matching
    # _compute_mask_to_parent, otherwise the mask shape is inconsistent.
    if not misc.is_shape_subset(np.shape(mask), plates_to_parent):
        fmt_args = (self.name,
                    index,
                    self.parents[index].name,
                    np.shape(mask),
                    plates_to_parent,
                    self.__class__.__name__)
        raise ValueError(
            "In node %s, the mask being sent to "
            "parent[%d] (%s) has invalid shape: The shape of "
            "the mask %s is not a sub-shape of the plates of "
            "the node with respect to the parent %s. It could "
            "be that this node (%s) is manipulating plates "
            "but has not overwritten the method "
            "_compute_mask_to_parent." % fmt_args)

    target_plates = self.parents[index].plates
    or_axes = misc.axes_to_collapse(np.shape(mask), target_plates)
    return misc.squeeze_to_dim(np.any(mask, axis=or_axes, keepdims=True),
                               len(target_plates))
def message_sum_multiply(plates_parent, dims_parent, *arrays):
    """
    Compute a message to a parent as a sum-product, summing over plates.

    The arrays are multiplied together (with broadcasting) and summed over
    every axis that the parent shape ``plates_parent + dims_parent`` does
    not have or has with unit length. The result is divided by the total
    size of the summed *plate* axes. Summing here already accounts for
    those plates, so this division cancels the plate multiplier that
    message_to_parent applies afterwards.

    Parameters
    ----------
    plates_parent : tuple
        Plate shape of the parent.
    dims_parent : tuple
        Variable dimensions of the parent.
    *arrays : array_like
        Factors of the message.

    Returns
    -------
    ndarray
        The summed message with ``len(plates_parent + dims_parent)`` axes.
    """
    # Shape of the full (broadcast) message
    shapes = [np.shape(array) for array in arrays]
    shape_full = misc.broadcasted_shape(*shapes)

    # Axes that must be summed to reach the parent's shape
    shape_parent = plates_parent + dims_parent
    sum_axes = misc.axes_to_collapse(shape_full, shape_parent)

    # Multiplier cancelling the later plate multiplier: product of the sizes
    # of the summed axes that are plates (not the trailing variable dims).
    ndim_full = len(shape_full)
    ndim_dims = len(dims_parent)
    r = 1
    for axis in sum_axes:
        # Measure from the right, where shape_full and the parent shape are
        # aligned under broadcasting. The old positive-index test compared
        # against len(plates_parent), which misclassifies leading axes that
        # the parent does not have.
        neg_axis = axis - ndim_full if axis >= 0 else axis
        if neg_axis < -ndim_dims:
            r *= shape_full[neg_axis]

    # Compute the sum-product
    m = misc.sum_multiply(*arrays,
                          axis=sum_axes,
                          sumaxis=True,
                          keepdims=True) / r

    # Remove extra leading axes
    return misc.squeeze_to_dim(m, len(shape_parent))
def _compute_weights_to_parent(self, index, weights):
    """
    Sum this node's weights over the tiles back onto parent[index].

    Each axis of ``weights`` is viewed as an interleaved (tile, parent-plate)
    pair of axes, and the tile axes are summed out.

    Parameters
    ----------
    index : int
        Index of the parent in ``self.parents``.
    weights : array_like
        Weights with respect to this node's plates.

    Returns
    -------
    ndarray
        Weights squeezed to ``len(self.parents[index].plates)`` axes.
    """
    # NOTE(review): ``tiles`` is neither a parameter nor a local here; it
    # must come from an enclosing or global scope (possibly meant to be
    # ``self.tiles``) -- confirm.
    plates = self._plates_to_parent(index)
    shape_m = np.shape(weights)
    (plates, tiles_m, shape_m) = misc.make_equal_length(plates, tiles, shape_m)

    # Axes on which the weights have unit length are broadcast; treat both
    # their tile and plate factor as 1 so the reshape below matches the
    # array's actual size.
    plates = [1 if n == 1 else p for p, n in zip(plates, shape_m)]
    tiles_m = [1 if n == 1 else t for t, n in zip(tiles_m, shape_m)]

    # Interleave (tile_0, plate_0, tile_1, plate_1, ...). The () initializer
    # keeps this valid when there are no axes at all (reduce on an empty
    # sequence without an initializer raises TypeError).
    shape = functools.reduce(lambda x, y: x + y, zip(tiles_m, plates), ())
    weights = np.reshape(weights, shape)

    # Tile axes sit at the even positions: sum them out.
    axes = tuple(range(0, len(shape), 2))
    weights = np.sum(weights, axis=axes)

    # Remove extra leading axes
    ndim_parent = len(self.parents[index].plates)
    return misc.squeeze_to_dim(weights, ndim_parent)
def _compute_message_to_parent(self, index, m, u_X):
    """
    Sum the message components of this tiled node back onto parent[index].

    For every component ``m[ind]``, each axis is viewed as an interleaved
    (tile, parent-axis) pair and the tile axes are summed out. Axes on which
    the message has unit length are broadcast over the tiles. Those axes are
    left out of the reshape, and the result is scaled by the number of tiles
    along them instead.

    Parameters
    ----------
    index : int
        Index of the parent in ``self.parents``.
    m : sequence of array_like
        Message components, one per entry of ``self.dims``.
    u_X : object
        Not used by this implementation.

    Returns
    -------
    list
        Message components shaped for ``self.parents[index]``.
    """
    # NOTE(review): ``tiles`` is neither a parameter nor a local here; it
    # must come from an enclosing or global scope (possibly meant to be
    # ``self.tiles``) and must be a tuple -- confirm.
    m = list(m)
    for ind in range(len(m)):
        shape_ind = self._plates_to_parent(index) + self.dims[ind]
        # Variable dimensions are never tiled.
        tiles_ind = tiles + (1,) * len(self.dims[ind])

        # Make shape tuples equal length
        shape_m = np.shape(m[ind])
        (tiles_ind, shape, shape_m) = misc.make_equal_length(tiles_ind,
                                                             shape_ind,
                                                             shape_m)

        # Unit-length message axes are broadcast over the tiles: collapse
        # them in the reshape and compensate with the multiplier r.
        r = 1
        shape = list(shape)
        tiles_ind = list(tiles_ind)
        for j in range(len(shape)):
            if shape_m[j] == 1:
                r *= tiles_ind[j]
                shape[j] = 1
                tiles_ind[j] = 1

        # Interleave (tile_0, size_0, tile_1, size_1, ...). The () initializer
        # keeps this valid when there are no axes at all.
        shape = functools.reduce(lambda x, y: x + y, zip(tiles_ind, shape), ())
        m[ind] = np.reshape(m[ind], shape)

        # Tile axes sit at the even positions: sum them out.
        axes = tuple(range(0, len(shape), 2))
        m[ind] = r * np.sum(m[ind], axis=axes)

        # Remove extra leading axes
        ndim_parent = len(self.parents[index].get_shape(ind))
        m[ind] = misc.squeeze_to_dim(m[ind], ndim_parent)
    return m
def _compute_message_to_parent(self, index, m, u_X):
    """
    Sum the message components of this tiled node back onto parent[index].

    For every component ``m[ind]``, each axis is viewed as an interleaved
    (tile, parent-axis) pair and the tile axes are summed out. Axes on which
    the message has unit length are broadcast over the tiles. Those axes are
    left out of the reshape, and the result is scaled by the number of tiles
    along them instead.

    Parameters
    ----------
    index : int
        Index of the parent in ``self.parents``.
    m : sequence of array_like
        Message components, one per entry of ``self.dims``.
    u_X : object
        Not used by this implementation.

    Returns
    -------
    list
        Message components shaped for ``self.parents[index]``.
    """
    # NOTE(review): ``tiles`` is neither a parameter nor a local here; it
    # must come from an enclosing or global scope (possibly meant to be
    # ``self.tiles``) and must be a tuple -- confirm.
    m = list(m)
    for ind in range(len(m)):
        shape_ind = self._plates_to_parent(index) + self.dims[ind]
        # Variable dimensions are never tiled.
        tiles_ind = tiles + (1, ) * len(self.dims[ind])

        # Make shape tuples equal length
        shape_m = np.shape(m[ind])
        (tiles_ind, shape, shape_m) = misc.make_equal_length(tiles_ind,
                                                             shape_ind,
                                                             shape_m)

        # Unit-length message axes are broadcast over the tiles: collapse
        # them in the reshape and compensate with the broadcasting
        # multiplier r.
        r = 1
        shape = list(shape)
        tiles_ind = list(tiles_ind)
        for j in range(len(shape)):
            if shape_m[j] == 1:
                r *= tiles_ind[j]
                shape[j] = 1
                tiles_ind[j] = 1

        # Interleave (tile_0, size_0, tile_1, size_1, ...). The () initializer
        # keeps this valid when there are no axes at all.
        shape = functools.reduce(lambda x, y: x + y, zip(tiles_ind, shape), ())
        m[ind] = np.reshape(m[ind], shape)

        # Tile axes sit at the even positions: sum them out.
        axes = tuple(range(0, len(shape), 2))
        m[ind] = r * np.sum(m[ind], axis=axes)

        # Remove extra leading axes
        ndim_parent = len(self.parents[index].get_shape(ind))
        m[ind] = misc.squeeze_to_dim(m[ind], ndim_parent)
    return m
def m_function(*args): lpdf = m(*args) # Log pdf only contains plate axes! plates_m = np.shape(lpdf) r = (self.broadcasting_multiplier(plates_self, plates_m, plates_mask, parent.plates) * self.broadcasting_multiplier(self.plates_multiplier, multiplier_parent)) axes_msg = misc.axes_to_collapse(plates_m, parent.plates) m[i] = misc.sum_multiply(mask_i, m[i], r, axis=axes_msg, keepdims=True) # Remove leading singular plates if the parent does not have # those plate axes. m[i] = misc.squeeze_to_dim(m[i], len(shape_parent))
def _compute_mask_to_parent(self, index, mask):
    """
    Reduce this node's mask over the tiles onto the plates of parent[index].

    Each axis of ``mask`` is viewed as an interleaved (tile, parent-plate)
    pair of axes. A logical OR is taken over the tile axes, so a parent plate
    is active if any of its tiles is active. No multiplier is needed for
    broadcast (unit-length) axes, because an OR of identical values is
    unchanged.

    Parameters
    ----------
    index : int
        Index of the parent in ``self.parents``.
    mask : array_like
        Mask with respect to this node's plates.

    Returns
    -------
    ndarray
        Boolean mask squeezed to ``len(self.parents[index].plates)`` axes.
    """
    # NOTE(review): ``tiles`` is neither a parameter nor a local here; it
    # must come from an enclosing or global scope (possibly meant to be
    # ``self.tiles``) -- confirm.
    plates = self._plates_to_parent(index)
    shape_m = np.shape(mask)
    (plates, tiles_m, shape_m) = misc.make_equal_length(plates, tiles, shape_m)

    # Axes on which the mask has unit length are broadcast; treat both their
    # tile and plate factor as 1 so the reshape matches the array's size.
    plates = [1 if n == 1 else p for p, n in zip(plates, shape_m)]
    tiles_m = [1 if n == 1 else t for t, n in zip(tiles_m, shape_m)]

    # Interleave (tile_0, plate_0, tile_1, plate_1, ...). The () initializer
    # keeps this valid when there are no axes at all (reduce on an empty
    # sequence without an initializer raises TypeError).
    shape = functools.reduce(lambda x, y: x + y, zip(tiles_m, plates), ())
    mask = np.reshape(mask, shape)

    # Tile axes sit at the even positions: OR over them.
    axes = tuple(range(0, len(shape), 2))
    mask = np.any(mask, axis=axes)

    # Remove extra leading axes
    ndim_parent = len(self.parents[index].plates)
    return misc.squeeze_to_dim(mask, ndim_parent)
def _message_to_parent(self, index):
    """
    Compute the message to parent[index], apply the mask and sum over the
    plates the parent does not have.

    Parameters
    ----------
    index : int
        Index of the parent in ``self.parents``.

    Returns
    -------
    list
        The message components. ``None`` components (empty messages) are
        passed through unchanged. Each other component is multiplied by the
        mask and the plate multiplier, then reduced to
        ``parent.get_shape(i)``.

    Raises
    ------
    ValueError
        If ``index >= len(self.parents)``, or if the plates of the message,
        the mask and the parent are not a broadcastable subset of the plates
        of this node.
    """
    if index >= len(self.parents):
        raise ValueError("Parent index larger than the number of parents")

    # Compute the message and mask
    (m, mask) = self._get_message_and_mask_to_parent(index)
    # Work on a list copy so components can be replaced even if a tuple was
    # returned, without mutating the container that
    # _get_message_and_mask_to_parent handed back.
    m = list(m)
    mask = misc.squeeze(mask)
    plates_mask = np.shape(mask)

    # The parent we're sending the message to
    parent = self.parents[index]
    # Plates of this node w.r.t. the parent; the same for every component.
    plates_self = self._plates_to_parent(index)

    for i, msg in enumerate(m):
        # Empty messages are given as None. We can ignore those.
        if msg is None:
            continue

        # Plates in the message (strip the parent's variable dimensions)
        shape_m = np.shape(msg)
        dim_parent = len(parent.dims[i])
        plates_m = shape_m[:-dim_parent] if dim_parent > 0 else shape_m

        # Multiply by the number of plates for which the message, the mask
        # and the parent all have single plates. Such a plate is meant to be
        # broadcast, but because the parent's plate axis is singular nothing
        # sums over it, so the multiplier accounts for it instead.
        try:
            r = self._plate_multiplier(plates_self,
                                       plates_m,
                                       plates_mask,
                                       parent.plates)
        except ValueError:
            raise ValueError("The plates of the message, the mask and "
                             "parent[%d] node (%s) are not a "
                             "broadcastable subset of the plates of "
                             "this node (%s). The message has shape "
                             "%s, meaning plates %s. The mask has "
                             "plates %s. This node has plates %s with "
                             "respect to the parent[%d], which has "
                             "plates %s."
                             % (index,
                                parent.name,
                                self.name,
                                shape_m,
                                plates_m,
                                plates_mask,
                                plates_self,
                                index,
                                parent.plates))

        # Add variable axes to the mask
        shape_mask = plates_mask + (1,) * dim_parent
        mask_i = np.reshape(mask, shape_mask)

        # Sum the mask over plates that are in neither the message nor the
        # parent.
        shape_parent = parent.get_shape(i)
        shape_msg = misc.broadcasted_shape(shape_m, shape_parent)
        axes_mask = misc.axes_to_collapse(shape_mask, shape_msg)
        mask_i = np.sum(mask_i, axis=axes_mask, keepdims=True)

        # Compute the masked message and sum over the plates that the
        # parent does not have.
        axes_msg = misc.axes_to_collapse(shape_msg, shape_parent)
        m[i] = misc.sum_multiply(mask_i, msg, r,
                                 axis=axes_msg,
                                 keepdims=True)

        # Remove leading singular plates if the parent does not have those
        # plate axes.
        m[i] = misc.squeeze_to_dim(m[i], len(shape_parent))

    return m