def _message_sum_multiply(plates_parent, dims_parent, *arrays):
    """
    Compute message to parent and sum over plates.

    The arrays are multiplied together (with broadcasting) and summed over
    the axes of the full message that the parent does not have or has as
    singleton axes. The result is divided by the product of the sizes of
    the summed *plate* axes, to cancel the plate-multiplier applied later
    in ``message_to_parent``.

    Parameters
    ----------
    plates_parent : tuple
        Plate shape of the parent node.
    dims_parent : tuple
        Shape of the parent's variable dimensions. These are the trailing
        axes of the message.
    *arrays : array_like
        Factors of the message. Their broadcasted shape is the shape of the
        full message.

    Returns
    -------
    m : ndarray
        The summed and rescaled message, squeezed to
        ``len(plates_parent + dims_parent)`` dimensions.
    """
    # The shape of the full message
    shapes = [np.shape(array) for array in arrays]
    shape_full = utils.broadcasted_shape(*shapes)

    # Find axes that should be summed
    shape_parent = plates_parent + dims_parent
    sum_axes = utils.axes_to_collapse(shape_full, shape_parent)

    # Compute the multiplier for cancelling the plate-multiplier. Because we
    # are summing over the dimensions already in this function (for
    # efficiency), we need to cancel the effect of the plate-multiplier
    # applied in the message_to_parent function.
    #
    # An axis of the full message is a plate axis iff it is not one of the
    # trailing len(dims_parent) axes. Normalize each index to its negative
    # form so the test holds whether axes_to_collapse reports negative or
    # non-negative indices. A non-negative index compared against
    # len(plates_parent) is only correct when shape_full has exactly as
    # many axes as the parent.
    ndim_full = len(shape_full)
    ndim_dims = len(dims_parent)
    r = 1
    for j in sum_axes:
        j_neg = j - ndim_full if j >= 0 else j
        if j_neg < -ndim_dims:
            r *= shape_full[j_neg]

    # Compute the sum-product
    m = utils.sum_multiply(*arrays,
                           axis=sum_axes,
                           sumaxis=True,
                           keepdims=True) / r

    # Remove extra axes
    m = utils.squeeze_to_dim(m, len(shape_parent))
    return m
def message_sum_multiply(plates_parent, dims_parent, *arrays):
    """
    Compute message to parent and sum over plates.

    Divide by the plate multiplier.

    The arrays are multiplied together (with broadcasting) and summed over
    the axes of the full message that the parent does not have or has as
    singleton axes. The result is divided by the product of the sizes of
    the summed *plate* axes, to cancel the plate-multiplier applied later
    in ``message_to_parent``.

    Parameters
    ----------
    plates_parent : tuple
        Plate shape of the parent node.
    dims_parent : tuple
        Shape of the parent's variable dimensions. These are the trailing
        axes of the message.
    *arrays : array_like
        Factors of the message. Their broadcasted shape is the shape of the
        full message.

    Returns
    -------
    m : ndarray
        The summed and rescaled message, squeezed to
        ``len(plates_parent + dims_parent)`` dimensions.
    """
    # The shape of the full message
    shapes = [np.shape(array) for array in arrays]
    shape_full = utils.broadcasted_shape(*shapes)

    # Find axes that should be summed
    shape_parent = plates_parent + dims_parent
    sum_axes = utils.axes_to_collapse(shape_full, shape_parent)

    # Compute the multiplier for cancelling the plate-multiplier. Because we
    # are summing over the dimensions already in this function (for
    # efficiency), we need to cancel the effect of the plate-multiplier
    # applied in the message_to_parent function.
    #
    # An axis of the full message is a plate axis iff it is not one of the
    # trailing len(dims_parent) axes. Normalize each index to its negative
    # form so the test holds whether axes_to_collapse reports negative or
    # non-negative indices. A non-negative index compared against
    # len(plates_parent) is only correct when shape_full has exactly as
    # many axes as the parent.
    ndim_full = len(shape_full)
    ndim_dims = len(dims_parent)
    r = 1
    for j in sum_axes:
        j_neg = j - ndim_full if j >= 0 else j
        if j_neg < -ndim_dims:
            r *= shape_full[j_neg]

    # Compute the sum-product
    m = utils.sum_multiply(*arrays,
                           axis=sum_axes,
                           sumaxis=True,
                           keepdims=True) / r

    # Remove extra axes
    m = utils.squeeze_to_dim(m, len(shape_parent))
    return m
def _message_to_parent(self, index):
    """
    Compute the message to a parent, check plates, apply the mask and sum
    over plates.

    Parameters
    ----------
    index : int
        Index of the parent in ``self.parents``.

    Returns
    -------
    m : list
        The messages from ``self._get_message_and_mask_to_parent(index)``.
        Each non-None entry is masked, multiplied by the plate multiplier,
        summed over the axes the parent does not have and squeezed to
        ``len(parent.get_shape(i))`` dimensions. None entries (empty
        messages) are left untouched.

    Raises
    ------
    ValueError
        If ``index`` is not smaller than the number of parents, or if the
        plates of a message, the mask and the parent are not a
        broadcastable subset of this node's plates with respect to that
        parent.
    """
    if index >= len(self.parents):
        raise ValueError("Parent index larger than the number of parents")

    # Compute the message and mask
    (m, mask) = self._get_message_and_mask_to_parent(index)
    mask = utils.squeeze(mask)

    # Plates in the mask
    plates_mask = np.shape(mask)

    # The parent we're sending the message to
    parent = self.parents[index]

    # Plates of this node with respect to the parent. This does not depend
    # on the message component, so compute it once instead of per message.
    plates_self = self._plates_to_parent(index)

    # Compact the message to a proper shape
    for i, m_i in enumerate(m):
        # Empty messages are given as None. We can ignore those.
        if m_i is None:
            continue

        # Plates in the message: strip the parent's variable axes.
        shape_m = np.shape(m_i)
        dim_parent = len(parent.dims[i])
        if dim_parent > 0:
            plates_m = shape_m[:-dim_parent]
        else:
            plates_m = shape_m

        # Compute the multiplier (multiply by the number of plates for which
        # the message, the mask and the parent have single plates). Such a
        # plate is meant to be broadcasted, but because the parent has a
        # singular plate axis it won't broadcast (and sum over it), so we
        # need to multiply it.
        try:
            r = self._plate_multiplier(plates_self,
                                       plates_m,
                                       plates_mask,
                                       parent.plates)
        except ValueError as err:
            raise ValueError("The plates of the message, the mask and "
                             "parent[%d] node (%s) are not a "
                             "broadcastable subset of the plates of "
                             "this node (%s). The message has shape "
                             "%s, meaning plates %s. The mask has "
                             "plates %s. This node has plates %s with "
                             "respect to the parent[%d], which has "
                             "plates %s."
                             % (index,
                                parent.name,
                                self.name,
                                np.shape(m_i),
                                plates_m,
                                plates_mask,
                                plates_self,
                                index,
                                parent.plates)) from err

        # Add variable axes to the mask
        shape_mask = plates_mask + (1,) * dim_parent
        mask_i = np.reshape(mask, shape_mask)

        # Sum over plates that are not in the message nor in the parent
        shape_parent = parent.get_shape(i)
        shape_msg = utils.broadcasted_shape(shape_m, shape_parent)
        axes_mask = utils.axes_to_collapse(shape_mask, shape_msg)
        mask_i = np.sum(mask_i, axis=axes_mask, keepdims=True)

        # Compute the masked message and sum over the plates that the
        # parent does not have.
        axes_msg = utils.axes_to_collapse(shape_msg, shape_parent)
        m_i = utils.sum_multiply(mask_i, m_i, r,
                                 axis=axes_msg,
                                 keepdims=True)

        # Remove leading singular plates if the parent does not have
        # those plate axes.
        m[i] = utils.squeeze_to_dim(m_i, len(shape_parent))

    return m