def _compute_moments(self, *u_parents):
    """
    Concatenate the parents' moments along the plate axis ``self._axis``.

    ``np.concatenate`` does not broadcast, but moment messages may rely on
    broadcasting.  As a workaround, each parent's moment array is broadcast
    explicitly so that all arrays agree on every axis except the
    concatenation axis.  Along that axis the k-th parent is broadcast to
    length ``self._lengths[k]``.

    Returns
    -------
    list
        One concatenated array per entry of ``self.dims``.
    """
    moments = []
    for (ind, dims_ind) in enumerate(self.dims):
        # The plate axis, re-expressed as an axis of the full moment array
        # (variable dimensions sit after the plates).
        concat_axis = self._axis - len(dims_ind)

        parent_moments = [u_parent[ind] for u_parent in u_parents]

        # Shapes with the concatenation axis replaced by 1 (when the array
        # actually has that axis), so they can be broadcast together.
        reduced_shapes = [list(np.shape(um)) for um in parent_moments]
        for sh in reduced_shapes:
            if len(sh) >= abs(concat_axis):
                sh[concat_axis] = 1
        common_shape = misc.broadcasted_shape(*reduced_shapes)

        # The concatenation axis itself must be broadcast explicitly too,
        # using each parent's own length along that axis.
        tail_ones = (1,) * (abs(concat_axis) - 1)
        target_shapes = [
            misc.broadcasted_shape(common_shape, (length,) + tail_ones)
            for length in self._lengths
        ]

        broadcasted = [
            um * np.ones(target)
            for (um, target) in zip(parent_moments, target_shapes)
        ]
        moments.append(np.concatenate(broadcasted, axis=concat_axis))
    return moments
def message_sum_multiply(plates_parent, dims_parent, *arrays):
    """
    Compute message to parent and sum over plates.

    The arrays are multiplied together (with broadcasting) and summed over
    the axes that ``misc.axes_to_collapse`` reports for reducing the full
    shape to ``plates_parent + dims_parent``.  The sum is divided by the
    product of the summed plate-axis lengths.  This cancels the
    plate-multiplier applied later in ``message_to_parent``, because the
    summing already happens here for efficiency.  Extra leading axes are
    squeezed away.
    """
    shape_full = misc.broadcasted_shape(*(np.shape(a) for a in arrays))

    shape_parent = plates_parent + dims_parent
    sum_axes = misc.axes_to_collapse(shape_full, shape_parent)

    n_plates = len(plates_parent)
    n_dims = len(dims_parent)
    divisor = 1
    for axis in sum_axes:
        # Only plate axes contribute to the plate multiplier: non-negative
        # axes below len(plates_parent), or negative axes in front of the
        # variable dimensions.
        is_plate_axis = (axis < n_plates) if axis >= 0 else (axis < -n_dims)
        if is_plate_axis:
            divisor *= shape_full[axis]

    summed = misc.sum_multiply(*arrays,
                               axis=sum_axes,
                               sumaxis=True,
                               keepdims=True) / divisor
    return misc.squeeze_to_dim(summed, len(shape_parent))
def _message_from_children(self, u_self=None):
    """
    Collect and sum the messages sent to this node by its children.

    Each child message is ``child._message_to_parent(index, u_parent=u_self)``
    and is handled in one of two ways:

    * A list of arrays (one per entry of ``self.dims``; ``None`` entries are
      skipped).  These are summed into arrays that start as zeros of shape
      ``self.dims[i]``, with broadcasting.
    * A callable.  All callables are combined into one callable whose result
      is the elementwise sum of the first two components of each child's
      result.

    Mixing the two kinds raises ``NotImplementedError``.

    Parameters
    ----------
    u_self : optional
        Forwarded unchanged to each child's ``_message_to_parent``.

    Returns
    -------
    list of arrays, or callable
    """

    def join_functions(f_new, f_old):
        # Build the combined function from explicitly bound arguments.  A
        # lambda closing over the loop variables would look them up at call
        # time: ``msg`` would then refer to the lambda itself (infinite
        # recursion) and ``m`` only to the last child's message.
        def joined(x):
            m1 = f_new(x)
            m2 = f_old(x)
            return (m1[0] + m2[0], m1[1] + m2[1])
        return joined

    msg = [np.zeros(shape) for shape in self.dims]
    isfunction = None
    for (child, index) in self.children:
        m = child._message_to_parent(index, u_parent=u_self)
        if callable(m):
            if isfunction is False:
                raise NotImplementedError()
            elif isfunction is None:
                msg = m
            else:
                msg = join_functions(m, msg)
            isfunction = True
        else:
            if isfunction is True:
                raise NotImplementedError()
            isfunction = False
            for i in range(len(self.dims)):
                if m[i] is not None:
                    # Raises if the message does not broadcast with this
                    # node's shape.
                    misc.broadcasted_shape(self.get_shape(i), np.shape(m[i]))
                    try:
                        # In-place add works when the result fits msg[i]
                        msg[i] += m[i]
                    except ValueError:
                        # The message broadcasts msg[i] to a larger shape
                        msg[i] = msg[i] + m[i]
    return msg
def _message_from_children(self, u_self=None):
    """
    Collect and sum the messages sent to this node by its children.

    Each child message is ``child._message_to_parent(index, u_parent=u_self)``
    and is handled in one of two ways:

    * A list of arrays (one per entry of ``self.dims``; ``None`` entries are
      skipped).  These are summed into arrays that start as zeros of shape
      ``self.dims[i]``, with broadcasting.
    * A callable.  All callables are combined into one callable whose result
      is the elementwise sum of the first two components of each child's
      result.

    Mixing the two kinds raises ``NotImplementedError``.

    Parameters
    ----------
    u_self : optional
        Forwarded unchanged to each child's ``_message_to_parent``.

    Returns
    -------
    list of arrays, or callable
    """

    def join_functions(f_new, f_old):
        # Build the combined function from explicitly bound arguments.  A
        # lambda closing over the loop variables would look them up at call
        # time: ``msg`` would then refer to the lambda itself (infinite
        # recursion) and ``m`` only to the last child's message.
        def joined(x):
            m1 = f_new(x)
            m2 = f_old(x)
            return (m1[0] + m2[0], m1[1] + m2[1])
        return joined

    msg = [np.zeros(shape) for shape in self.dims]
    isfunction = None
    for (child, index) in self.children:
        m = child._message_to_parent(index, u_parent=u_self)
        if callable(m):
            if isfunction is False:
                raise NotImplementedError()
            elif isfunction is None:
                msg = m
            else:
                msg = join_functions(m, msg)
            isfunction = True
        else:
            if isfunction is True:
                raise NotImplementedError()
            isfunction = False
            for i in range(len(self.dims)):
                if m[i] is not None:
                    # Raises if the message does not broadcast with this
                    # node's shape.
                    misc.broadcasted_shape(self.get_shape(i), np.shape(m[i]))
                    try:
                        # In-place add works when the result fits msg[i]
                        msg[i] += m[i]
                    except ValueError:
                        # The message broadcasts msg[i] to a larger shape
                        msg[i] = msg[i] + m[i]
    return msg
def _subplots(plotfunc, *args, fig=None, kwargs=None):
    """Create a collection of subplots

    Each subplot is created with the same plotting function.

    Inputs are given as pairs:

    (x, 3), (y, 2), ...

    where x,y,... are the input arrays and 3,2,... are the ndim
    parameters. The last ndim axes of each array are interpreted as a
    single element to the plotting function. The remaining leading axes
    (plates) of all inputs are broadcast together, and each plate index
    gets its own subplot.

    All high-level plotting functions should wrap low-level plotting
    functions with this function in order to generate subplots for
    plates.

    Parameters
    ----------
    plotfunc : callable
        Called as ``plotfunc(*elements, axes=axes, **kwargs)`` per subplot.
    *args : (array, int) pairs
        Input arrays and the number of trailing axes forming one element.
    fig : figure, optional
        Figure to draw into; defaults to ``plt.gcf()``.
    kwargs : dict, optional
        Extra keyword arguments forwarded to ``plotfunc``.
    """
    if kwargs is None:
        kwargs = {}
    if fig is None:
        fig = plt.gcf()

    # Plates of each input: everything except the trailing ndim axes
    plates = [np.shape(x)[:-n] if n > 0 else np.shape(x)
              for (x, n) in args]

    # Get the full grid shape of the subplots
    broadcasted_plates = misc.broadcasted_shape(*plates)

    # Subplot layout.  Counting plate axes from the end, the last,
    # third-to-last, ... axes span the columns and the second-to-last,
    # fourth-to-last, ... axes span the rows.  int() is required because
    # np.prod of an empty slice is the float 1.0, and matplotlib only
    # accepts integer grid sizes and subplot indices.
    M = int(np.prod(broadcasted_plates[-2::-2]))
    N = int(np.prod(broadcasted_plates[-1::-2]))
    nplates = len(broadcasted_plates)
    strides_subplot = []
    for j in range(nplates):
        stride = int(np.prod(broadcasted_plates[(j + 2)::2]))
        if (nplates - j) % 2 == 0:
            # Row axis: one step skips a full row of N subplots
            stride *= N
        strides_subplot.append(stride)

    # Plot each subplot
    for ind in misc.nested_iterator(broadcasted_plates):
        # Get the list of inputs for this subplot
        broadcasted_args = []
        for ((x, _), plates_x) in zip(args, plates):
            i = misc.safe_indices(ind, plates_x)
            broadcasted_args.append(x[i])

        # Flat (0-based) subplot index as a Python int
        ind_subplot = sum(int(k) * s for (k, s) in zip(ind, strides_subplot))
        axes = fig.add_subplot(M, N, ind_subplot + 1)
        plotfunc(*broadcasted_args, axes=axes, **kwargs)
def _subplots(plotfunc, *args, fig=None, kwargs=None):
    """Create a collection of subplots

    Each subplot is created with the same plotting function.

    Inputs are given as pairs:

    (x, 3), (y, 2), ...

    where x,y,... are the input arrays and 3,2,... are the ndim
    parameters. The last ndim axes of each array are interpreted as a
    single element to the plotting function. The remaining leading axes
    (plates) of all inputs are broadcast together, and each plate index
    gets its own subplot.

    All high-level plotting functions should wrap low-level plotting
    functions with this function in order to generate subplots for
    plates.

    Parameters
    ----------
    plotfunc : callable
        Called as ``plotfunc(*elements, axes=axes, **kwargs)`` per subplot.
    *args : (array, int) pairs
        Input arrays and the number of trailing axes forming one element.
    fig : figure, optional
        Figure to draw into; defaults to ``plt.gcf()``.
    kwargs : dict, optional
        Extra keyword arguments forwarded to ``plotfunc``.
    """
    if kwargs is None:
        kwargs = {}
    if fig is None:
        fig = plt.gcf()

    # Plates of each input: everything except the trailing ndim axes
    plates = [np.shape(x)[:-n] if n > 0 else np.shape(x)
              for (x, n) in args]

    # Get the full grid shape of the subplots
    broadcasted_plates = misc.broadcasted_shape(*plates)

    # Subplot layout.  Counting plate axes from the end, the last,
    # third-to-last, ... axes span the columns and the second-to-last,
    # fourth-to-last, ... axes span the rows.  int() is required because
    # np.prod of an empty slice is the float 1.0, and matplotlib only
    # accepts integer grid sizes and subplot indices.
    M = int(np.prod(broadcasted_plates[-2::-2]))
    N = int(np.prod(broadcasted_plates[-1::-2]))
    nplates = len(broadcasted_plates)
    strides_subplot = []
    for j in range(nplates):
        stride = int(np.prod(broadcasted_plates[(j + 2)::2]))
        if (nplates - j) % 2 == 0:
            # Row axis: one step skips a full row of N subplots
            stride *= N
        strides_subplot.append(stride)

    # Plot each subplot
    for ind in misc.nested_iterator(broadcasted_plates):
        # Get the list of inputs for this subplot
        broadcasted_args = []
        for ((x, _), plates_x) in zip(args, plates):
            i = misc.safe_indices(ind, plates_x)
            broadcasted_args.append(x[i])

        # Flat (0-based) subplot index as a Python int
        ind_subplot = sum(int(k) * s for (k, s) in zip(ind, strides_subplot))
        axes = fig.add_subplot(M, N, ind_subplot + 1)
        plotfunc(*broadcasted_args, axes=axes, **kwargs)
def _plate_multiplier(plates, *args):
    """
    Compute the plate multiplier for given shapes.

    The first shape is compared to all other shapes (using NumPy
    broadcasting rules).  The result is the product of every element of
    ``plates`` for which each other shape is either missing that axis or
    has length 1 there.

    This is used, for instance, as a correction factor for messages to
    parents.  Plates that are non-unit in this node but unit in the parent
    are summed.  If the message itself has a unit axis for such a plate, it
    would first have to be broadcast to this node's plates and then summed.
    Multiplying by the plate length gives the same result without building
    the broadcast array.

    Parameters
    ----------
    plates : tuple
        Full plate shape of this node (with respect to the parent).
    *args : tuple
        Other shapes, e.g. the plates of the message array and of the parent.

    Returns
    -------
    int
        The multiplier (1 when no plate qualifies).

    Raises
    ------
    ValueError
        If a shape does not broadcast with ``plates`` or is not a sub-shape
        of it.
    """
    # Broadcasting check: misc.broadcasted_shape raises on incompatibility
    for shape in args:
        misc.broadcasted_shape(plates, shape)

    if not all(misc.is_shape_subset(shape, plates) for shape in args):
        raise ValueError("The shapes in args are not a sub-shape of "
                         "plates.")

    multiplier = 1
    for axis in range(-len(plates), 0):
        # The plate counts if no other shape has a non-unit length there
        if all(-axis > len(shape) or shape[axis] == 1 for shape in args):
            multiplier *= plates[axis]
    return multiplier
def _total_plates(cls, plates, *parent_plates):
    """
    Determine the plates of a node from its explicit plates and its parents.

    Parameters
    ----------
    plates : tuple or None
        Explicit plates.  If ``None``, the broadcast of all parents' plates
        is used.
    *parent_plates : tuple
        Plates of each parent.

    Returns
    -------
    tuple
        The resulting plates.

    Raises
    ------
    ValueError
        If the parents' plates do not broadcast together (when ``plates`` is
        None), or if some parent's plates are not a subset of ``plates``.
    """
    if plates is not None:
        # Explicit plates: each parent's plates must be a subset of them
        for p in parent_plates:
            if not misc.is_shape_subset(p, plates):
                raise ValueError("The plates %s of the parents "
                                 "are not broadcastable to the given "
                                 "plates %s." % (p, plates))
        return plates

    # Default: the minimal plates implied by the parents
    try:
        return misc.broadcasted_shape(*parent_plates)
    except ValueError:
        raise ValueError("The plates of the parents do not broadcast.")
def _message_from_children(self):
    """
    Sum the messages sent to this node by all of its children.

    For each ``(child, index)`` in ``self.children`` the message
    ``child._message_to_parent(index)`` is a sequence with one entry per
    element of ``self.dims``.  ``None`` entries are skipped; the others are
    added, with broadcasting, to arrays that start as zeros of shape
    ``self.dims[i]``.

    Returns
    -------
    list of ndarray
    """
    total = [np.zeros(shape) for shape in self.dims]
    for child, index in self.children:
        message = child._message_to_parent(index)
        for i in range(len(self.dims)):
            part = message[i]
            if part is None:
                continue
            # Raises if the message does not broadcast with this node's shape
            misc.broadcasted_shape(self.get_shape(i), np.shape(part))
            try:
                # In-place addition when the broadcast fits total[i]
                total[i] += part
            except ValueError:
                # The message enlarges total[i]; allocate a new array
                total[i] = total[i] + part
    return total
def _compute_moments(self, u_Z):
    """
    Compute the moments given the moments of the parents.

    ``u_Z[0]`` (p0) gets a singleton time axis inserted before its last
    axis.  The joint probability arrays ``u_Z[1]`` are summed over axis -2
    to give marginal probability vectors.  The two are broadcast to a
    common leading shape and concatenated along axis -2.

    Returns
    -------
    list
        ``[P]`` with the concatenated probabilities.
    """
    first = u_Z[0][..., None, :]
    rest = np.sum(u_Z[1], axis=-2)

    # np.concatenate doesn't broadcast: give both arrays the same leading
    # shape while leaving their last two axes untouched.
    leading = misc.broadcasted_shape(np.shape(first)[:-2], np.shape(rest)[:-2])
    ones = np.ones(leading + (1, 1))

    return [np.concatenate((first * ones, rest * ones), axis=-2)]
def _message_to_parent(self, index, u_parent=None):
    """
    Compute the message and mask to a parent node.

    The two moment messages to ``self.parents[index]`` are each computed
    with a single ``np.einsum`` call.  The call combines the moments of the
    other parents (``self._message_from_parents(exclude=index)``) with the
    message from the children (``self._message_from_children()``).  Parents
    flagged in ``self.is_constant`` contribute only their first moment
    ``u[0]``; for the second message it is used twice instead of a second
    moment.  When ``self.gaussian_gamma`` is set, two more messages built
    from ``m[2]`` and ``m[3]`` with ``self._compute_message`` are appended.

    Parameters
    ----------
    index : int
        Index of the target parent in ``self.parents``.
    u_parent : optional
        Not used in this body.

    Returns
    -------
    list
        ``[msg0, msg1]``, extended with ``[m2, m3]`` if
        ``self.gaussian_gamma``.

    Raises
    ------
    NotImplementedError
        If the target parent is flagged constant.
    ValueError
        If ``index`` is not smaller than ``len(self.parents)``.
    """
    # NOTE(review): is_constant[index] is read before the range check below,
    # so an out-of-range index may raise IndexError rather than the
    # ValueError below -- confirm the intended order.
    if self.is_constant[index]:
        raise NotImplementedError(
            "Message to DeltaMoments parent not yet implemented."
        )
    # Check index
    if index >= len(self.parents):
        raise ValueError("Parent index larger than the number of parents")
    # Get messages from other parents and children
    u_parents = self._message_from_parents(exclude=index)
    m = self._message_from_children()
    mask = self.mask
    # Normally we don't need to care about masks when computing the
    # message. However, in this node we want to avoid computing huge message
    # arrays so we sum some axes already here. Thus, we need to apply the
    # mask.
    #
    # Actually, we don't need to care about masks because the message from
    # children has already been masked.
    parent = self.parents[index]
    #
    # Compute the first message
    #
    msg = [None, None]
    # Compute the two messages
    for ind in range(2):
        # The total number of keys for the non-plate dimensions
        N = (ind+1) * self.N_keys
        # NOTE(review): parent_num_dims is not used below.
        parent_num_dims = len(parent.dims[ind])
        parent_num_plates = len(parent.plates)
        # Plate keys count down from N + num_plates to N + 1, so every
        # array's last plate axis gets key N + 1 -- plates are aligned from
        # the end, as in NumPy broadcasting.
        parent_plate_keys = list(range(N + parent_num_plates, N, -1))
        parent_dim_keys = self.in_keys[index]
        if ind == 1:
            # Second moment: a shifted copy of the keys for the extra
            # (outer-product) axes, followed by the original keys
            parent_dim_keys = ([key + self.N_keys
                                for key in self.in_keys[index]]
                               + parent_dim_keys)
        # Alternating operand / key-list arguments for np.einsum
        args = []
        # This variable counts the maximum number of plates of the
        # arguments, thus it will tell the number of plates in the result
        # (if the artificially added plates above were ignored).
        result_num_plates = 0
        result_plates = ()
        # Mask and its keys.  The mask itself is not passed to einsum (see
        # the comment above), but its plates still enter result_plates and
        # hence the plate multiplier at the end.
        mask_num_plates = np.ndim(mask)
        mask_plates = np.shape(mask)
        # NOTE(review): mask_plate_keys is not used below.
        mask_plate_keys = list(range(N + mask_num_plates, N, -1))
        result_num_plates = max(result_num_plates, mask_num_plates)
        result_plates = misc.broadcasted_shape(result_plates, mask_plates)
        # Moments and keys of other parents
        for (k, u) in enumerate(u_parents):
            if k != index:
                # Constant parents provide only a first moment, so their
                # number of variable axes does not double for ind == 1
                num_dims = (
                    (ind+1) * len(self.in_keys[k])
                    if not self.is_constant[k] else
                    len(self.in_keys[k])
                )
                ui = (
                    u[ind] if not self.is_constant[k] else
                    u[0]
                )
                num_plates = np.ndim(ui) - num_dims
                plates = np.shape(ui)[:num_plates]
                plate_keys = list(range(N + num_plates, N, -1))
                if ind == 0:
                    args.append(ui)
                    args.append(plate_keys + self.in_keys[k])
                else:
                    in_keys2 = [key + self.N_keys for key in self.in_keys[k]]
                    if not self.is_constant[k]:
                        # Gaussian moments: Use second moment once
                        args.append(ui)
                        args.append(plate_keys + in_keys2 + self.in_keys[k])
                    else:
                        # Delta moments: Use first moment twice
                        args.append(ui)
                        args.append(plate_keys + self.in_keys[k])
                        args.append(ui)
                        args.append(plate_keys + in_keys2)
                result_num_plates = max(result_num_plates, num_plates)
                result_plates = misc.broadcasted_shape(result_plates, plates)
        # Message and keys from children
        child_num_dims = (ind+1) * len(self.out_keys)
        child_num_plates = np.ndim(m[ind]) - child_num_dims
        child_plates = np.shape(m[ind])[:child_num_plates]
        child_plate_keys = list(range(N + child_num_plates, N, -1))
        child_dim_keys = self.out_keys
        if ind == 1:
            child_dim_keys = ([key + self.N_keys
                               for key in self.out_keys]
                              + child_dim_keys)
        args.append(m[ind])
        args.append(child_plate_keys + child_dim_keys)
        result_num_plates = max(result_num_plates, child_num_plates)
        result_plates = misc.broadcasted_shape(result_plates, child_plates)
        # Output keys, that is, the keys of the parent[index]
        parent_keys = parent_plate_keys + parent_dim_keys
        # Performance trick: Check which axes can be summed because they
        # have length 1 or are non-existing in parent[index]. Thus, remove
        # keys corresponding to unit length axes in parent[index] so that
        # einsum sums over those axes. After computations, these axes must
        # be added back in order to get the correct shape for the message.
        # Also, remove axes/keys that are in output (parent[index]) but not in
        # any inputs (children and other parents).  einsum cannot produce
        # an output key that appears in no operand.
        parent_shape = parent.get_shape(ind)
        removed_axes = []
        for j in range(len(parent_keys)):
            if parent_shape[j] == 1:
                # Remove the key (take into account the number of keys that
                # have already been removed)
                del parent_keys[j-len(removed_axes)]
                removed_axes.append(j)
            else:
                # Remove the key if it doesn't appear in any of the
                # messages from children or other parents.
                if not np.any([parent_keys[j-len(removed_axes)] in keys
                               for keys in args[1::2]]):
                    del parent_keys[j-len(removed_axes)]
                    removed_axes.append(j)
        args.append(parent_keys)
        # THE BEEF: Compute the message
        msg[ind] = np.einsum(*args)
        # Find the correct shape for the message array
        message_shape = list(np.shape(msg[ind]))
        # First, add back the axes with length 1
        for ax in removed_axes:
            message_shape.insert(ax, 1)
        # Second, remove leading axes for plates that were not present in
        # the child nor other parents' messages. This is not really
        # necessary, but it is just elegant to remove the leading unit
        # length axes that we added artificially at the beginning just
        # because we wanted the key mapping to be simple.
        if parent_num_plates > result_num_plates:
            del message_shape[:(parent_num_plates-result_num_plates)]
        # Then, the actual reshaping
        msg[ind] = np.reshape(msg[ind], message_shape)
        # Broadcasting is not supported for variable dimensions, thus force
        # explicit correct shape for variable dimensions
        var_dims = parent.dims[ind]
        msg[ind] = msg[ind] * np.ones(var_dims)
        # Apply plate multiplier: If this node has non-unit plates that are
        # unit plates in the parent, those plates are summed. However, if
        # the message has unit axis for that plate, it should be first
        # broadcasted to the plates of this node and then summed to the
        # plates of the parent. In order to avoid this broadcasting and
        # summing, it is more efficient to just multiply by the correct
        # factor.
        r = self.broadcasting_multiplier(self.plates,
                                         result_plates,
                                         parent.plates)
        if r != 1:
            msg[ind] *= r
    if self.gaussian_gamma:
        # Third moment entry of each other parent; constant parents
        # contribute 1.0 instead
        alphas = [
            (u_parents[i][2] if not is_const else 1.0)
            for (i, is_const) in zip(range(len(u_parents)), self.is_constant)
            if i != index
        ]
        m2 = self._compute_message(m[2], mask, *alphas,
                                   ndim=0,
                                   plates_from=self.plates,
                                   plates_to=parent.plates)
        m3 = self._compute_message(m[3], mask,
                                   ndim=0,
                                   plates_from=self.plates,
                                   plates_to=parent.plates)
        msg = msg + [m2, m3]
    return msg
def _message_to_parent(self, index):
    """
    Compute the message and mask to a parent node.

    Each of the two moment messages to ``self.parents[index]`` is computed
    with a single ``np.einsum`` call.  The operands are:

    * an array of ones with the parent's shape (fixes the output rank);
    * the mask ``self.mask``;
    * the moments of the other parents;
    * the message from the children.

    When ``self.gaussian_gamma`` is set, two more messages built from
    ``m[2]`` and ``m[3]`` with ``self._compute_message`` are appended.

    Returns
    -------
    list
        ``[msg0, msg1]``, extended with ``[m2, m3]`` if
        ``self.gaussian_gamma``.

    Raises
    ------
    ValueError
        If ``index`` is not smaller than ``len(self.parents)``.
    """
    # Check index
    if index >= len(self.parents):
        raise ValueError("Parent index larger than the number of parents")
    # Get messages from other parents and children
    u_parents = self._message_from_parents(exclude=index)
    m = self._message_from_children()
    mask = self.mask
    # Normally we don't need to care about masks when computing the
    # message. However, in this node we want to avoid computing huge message
    # arrays so we sum some axes already here. Thus, we need to apply the
    # mask.
    parent = self.parents[index]
    #
    # Compute the first message
    #
    msg = [None, None]
    # Compute the two messages
    for ind in range(2):
        # The total number of keys for the non-plate dimensions
        N = (ind+1) * self.N_keys
        # Add an array of ones to ensure proper shape and number of
        # plates. Note that this adds an axis for each plate. At the end, we
        # want to remove axes that were created only because of this array.
        # NOTE(review): parent_num_dims is not used below.
        parent_num_dims = len(parent.dims[ind])
        parent_num_plates = len(parent.plates)
        # Plate keys count down from N + num_plates to N + 1, so the last
        # plate axis of every operand gets key N + 1 -- plates are aligned
        # from the end, as in NumPy broadcasting.
        parent_plate_keys = list(range(N + parent_num_plates, N, -1))
        parent_dim_keys = self.in_keys[index]
        if ind == 1:
            # Second moment: shifted keys for the extra axes, then the
            # original keys
            parent_dim_keys = ([key + self.N_keys
                                for key in self.in_keys[index]]
                               + parent_dim_keys)
        # Alternating operand / key-list arguments for np.einsum
        args = []
        args.append(np.ones((1,)*parent_num_plates + parent.dims[ind]))
        args.append(parent_plate_keys + parent_dim_keys)
        # This variable counts the maximum number of plates of the
        # arguments, thus it will tell the number of plates in the result
        # (if the artificially added plates above were ignored).
        result_num_plates = 0
        result_plates = ()
        # Mask and its keys
        mask_num_plates = np.ndim(mask)
        mask_plates = np.shape(mask)
        mask_plate_keys = list(range(N + mask_num_plates, N, -1))
        result_num_plates = max(result_num_plates, mask_num_plates)
        result_plates = misc.broadcasted_shape(result_plates, mask_plates)
        args.append(mask)
        args.append(mask_plate_keys)
        # Moments and keys of other parents
        for (k, u) in enumerate(u_parents):
            if k != index:
                num_dims = (ind+1) * len(self.in_keys[k])
                num_plates = np.ndim(u[ind]) - num_dims
                plates = np.shape(u[ind])[:num_plates]
                plate_keys = list(range(N + num_plates, N, -1))
                dim_keys = self.in_keys[k]
                if ind == 1:
                    dim_keys = ([key + self.N_keys
                                 for key in self.in_keys[k]]
                                + dim_keys)
                args.append(u[ind])
                args.append(plate_keys + dim_keys)
                result_num_plates = max(result_num_plates, num_plates)
                result_plates = misc.broadcasted_shape(result_plates, plates)
        # Message and keys from children
        child_num_dims = (ind+1) * len(self.out_keys)
        child_num_plates = np.ndim(m[ind]) - child_num_dims
        child_plates = np.shape(m[ind])[:child_num_plates]
        child_plate_keys = list(range(N + child_num_plates, N, -1))
        child_dim_keys = self.out_keys
        if ind == 1:
            child_dim_keys = ([key + self.N_keys
                               for key in self.out_keys]
                              + child_dim_keys)
        args.append(m[ind])
        args.append(child_plate_keys + child_dim_keys)
        result_num_plates = max(result_num_plates, child_num_plates)
        result_plates = misc.broadcasted_shape(result_plates, child_plates)
        # Output keys, that is, the keys of the parent[index]
        parent_keys = parent_plate_keys + parent_dim_keys
        # Performance trick: Check which axes can be summed because they
        # have length 1 or are non-existing in parent[index]. Thus, remove
        # keys corresponding to unit length axes in parent[index] so that
        # einsum sums over those axes. After computations, these axes must
        # be added back in order to get the correct shape for the message.
        parent_shape = parent.get_shape(ind)
        removed_axes = []
        for j in range(len(parent_keys)):
            if parent_shape[j] == 1:
                # Remove the key (take into account the number of keys that
                # have already been removed)
                del parent_keys[j-len(removed_axes)]
                removed_axes.append(j)
        args.append(parent_keys)
        # THE BEEF: Compute the message
        msg[ind] = np.einsum(*args)
        # Find the correct shape for the message array
        message_shape = list(np.shape(msg[ind]))
        # First, add back the axes with length 1
        for ax in removed_axes:
            message_shape.insert(ax, 1)
        # Second, remove leading axes for plates that were not present in
        # the child nor other parents' messages. This is not really
        # necessary, but it is just elegant to remove the leading unit
        # length axes that we added artificially at the beginning just
        # because we wanted the key mapping to be simple.
        if parent_num_plates > result_num_plates:
            del message_shape[:(parent_num_plates-result_num_plates)]
        # Then, the actual reshaping
        msg[ind] = np.reshape(msg[ind], message_shape)
        # Apply plate multiplier: If this node has non-unit plates that are
        # unit plates in the parent, those plates are summed. However, if
        # the message has unit axis for that plate, it should be first
        # broadcasted to the plates of this node and then summed to the
        # plates of the parent. In order to avoid this broadcasting and
        # summing, it is more efficient to just multiply by the correct
        # factor.
        r = self.broadcasting_multiplier(self.plates,
                                         result_plates,
                                         parent.plates)
        if r != 1:
            msg[ind] *= r
    if self.gaussian_gamma:
        # Third moment entry of each other parent
        alphas = [u_parents[i][2]
                  for i in range(len(u_parents)) if i != index]
        ## logalphas = [u_parents[i][3]
        ##              for i in range(len(u_parents)) if i != index]
        m2 = self._compute_message(mask, m[2], *alphas,
                                   ndim=0,
                                   plates_from=self.plates,
                                   plates_to=parent.plates)
        m3 = self._compute_message(mask, m[3],#*logalphas,
                                   ndim=0,
                                   plates_from=self.plates,
                                   plates_to=parent.plates)
        msg = msg + [m2, m3]
    return msg
def _compute_message(*arrays, plates_from=(), plates_to=(), ndim=0):
    """
    A general function for computing messages by sum-multiply

    The function computes the product of the input arrays and then sums to
    the requested plates.

    Parameters
    ----------
    *arrays : array_like
        Arrays to multiply (with broadcasting).  The last ``ndim`` axes of
        their broadcast shape are variable dimensions; the leading axes are
        plates.
    plates_from : tuple
        Full plates of the source.  Used to correct for plates that are
        broadcast (length 1) in the arrays but summed away.
    plates_to : tuple
        Plates to sum to.  Must be broadcastable to ``plates_from``.
    ndim : int
        Number of trailing variable-dimension axes.

    Returns
    -------
    ndarray
        Shape ``plates_result + dims``.  ``plates_result`` covers the last
        ``min(len(plates_to), nplates)`` plate axes; each entry is the
        smaller of the target plate and the array plate.

    Raises
    ------
    ValueError
        If ``plates_to`` is not broadcastable to ``plates_from``.
    """
    # Check that the plates broadcast properly
    if not misc.is_shape_subset(plates_to, plates_from):
        raise ValueError("plates_to must be broadcastable to plates_from")

    # Compute the explicit shape of the product
    shapes = [np.shape(array) for array in arrays]
    arrays_shape = misc.broadcasted_shape(*shapes)

    # Compute plates and dims that are present
    if ndim == 0:
        arrays_plates = arrays_shape
        dims = ()
    else:
        arrays_plates = arrays_shape[:-ndim]
        dims = arrays_shape[-ndim:]

    # Compute the correction term. If some of the plates that should be
    # summed are actually broadcasted, one must multiply by the size of the
    # corresponding plate
    r = Node.broadcasting_multiplier(plates_from, arrays_plates, plates_to)

    # For simplicity, make the arrays equal ndim
    arrays = misc.make_equal_ndim(*arrays)

    # Keys for the input plates: (N-1, N-2, ..., 0).  Key k labels the k-th
    # plate axis counted from the END, which is what the lookup
    # plates_to[-key-1] below relies on.  (Ascending keys would pair the
    # leading array axis with the last target plate.)
    nplates = len(arrays_plates)
    in_plate_keys = list(range(nplates - 1, -1, -1))
    # Keep a plate axis in the output only if the target has it with a
    # non-unit length; all other plate axes are summed by einsum
    out_plate_keys = [
        key for key in in_plate_keys
        if key < len(plates_to) and plates_to[-key - 1] != 1
    ]
    # Keys for the dims (disjoint from the plate keys)
    dim_keys = list(range(nplates, nplates + ndim))
    # Total input and output keys
    in_keys = len(arrays) * [in_plate_keys + dim_keys]
    out_keys = out_plate_keys + dim_keys

    # Compute the sum-product with correction
    einsum_args = misc.zipper_merge(arrays, in_keys) + [out_keys]
    y = r * np.einsum(*einsum_args)

    # Reshape the result: summed axes come back as unit axes
    nplates_result = min(len(plates_to), len(arrays_plates))
    if nplates_result == 0:
        plates_result = []
    else:
        plates_result = [
            min(plates_to[ind], arrays_plates[ind])
            for ind in range(-nplates_result, 0)
        ]
    y = np.reshape(y, plates_result + list(dims))

    return y
def _message_to_parent(self, index):
    """
    Compute the message from this node to parent ``index``.

    The raw message and mask come from
    ``self._get_message_and_mask_to_parent(index)``.  For every non-``None``
    message entry, the mask is applied and the result is summed to
    ``parent.get_shape(i)`` with ``misc.sum_multiply_to_plates``.  It is
    then scaled by the ratio of this node's plate multiplier to the
    parent's.

    If the raw message is callable (a log-pdf function, for black-box
    variational inference), a function is returned instead.  That function
    evaluates the log-pdf, whose array contains only plate axes, and reduces
    it the same way with ``ndim=0``.

    Raises
    ------
    ValueError
        If ``index`` is out of range, or if the plate multipliers of this
        node and the parent are incompatible.
    """
    # Compute the message, check plates, apply mask and sum over some plates
    if index >= len(self.parents):
        raise ValueError("Parent index larger than the number of parents")

    # Compute the message and mask
    (m, mask) = self._get_message_and_mask_to_parent(index)
    mask = misc.squeeze(mask)

    # The parent we're sending the message to
    parent = self.parents[index]

    # Plates with respect to the parent
    plates_self = self._plates_to_parent(index)

    # Plate multiplier of the parent
    multiplier_parent = self._plates_multiplier_from_parent(index)

    def plate_multiplier_ratio():
        # Ratio of this node's plate multiplier to the parent's, with a
        # descriptive error when the two are incompatible.
        try:
            return self.broadcasting_multiplier(self.plates_multiplier,
                                                multiplier_parent)
        except Exception as err:
            raise ValueError("The plate multipliers are incompatible. "
                             "This node (%s) has %s and parent[%d] "
                             "(%s) has %s"
                             % (self.name,
                                self.plates_multiplier,
                                index,
                                parent.name,
                                multiplier_parent)) from err

    # Check if m is a logpdf function (for black-box variational inference)
    if callable(m):
        def m_function(*args):
            # The log pdf only contains plate axes, so it is reduced exactly
            # like an ndim=0 message below: mask, sum to the parent's plates
            # and apply the plate multiplier.
            lpdf = m(*args)
            r = plate_multiplier_ratio()
            return r * misc.sum_multiply_to_plates(lpdf, mask,
                                                   to_plates=parent.plates,
                                                   from_plates=plates_self,
                                                   ndim=0)
        return m_function

    # Compact the message to a proper shape
    for i in range(len(m)):
        # Empty messages are given as None. We can ignore those.
        if m[i] is None:
            continue

        r = plate_multiplier_ratio()

        ndim = len(parent.dims[i])
        # Source and target shapes
        if ndim > 0:
            dims = misc.broadcasted_shape(np.shape(m[i])[-ndim:],
                                          parent.dims[i])
            from_shape = plates_self + dims
        else:
            from_shape = plates_self
        to_shape = parent.get_shape(i)

        # Add variable axes to the mask
        mask_i = misc.add_trailing_axes(mask, ndim)

        # Apply mask and sum plate axes as necessary (and apply plate
        # multiplier)
        m[i] = r * misc.sum_multiply_to_plates(m[i], mask_i,
                                               to_plates=to_shape,
                                               from_plates=from_shape,
                                               ndim=0)
    return m
def _message_to_parent(self, index):
    """
    Compute the message from this node to parent ``index``.

    The raw message and mask come from
    ``self._get_message_and_mask_to_parent(index)``.  For every non-``None``
    message entry, the mask is applied and the result is summed to
    ``parent.get_shape(i)`` with ``misc.sum_multiply_to_plates``.  It is
    then scaled by the ratio of this node's plate multiplier to the
    parent's.

    If the raw message is callable (a log-pdf function, for black-box
    variational inference), a function is returned instead.  That function
    evaluates the log-pdf, whose array contains only plate axes, and reduces
    it the same way with ``ndim=0``.

    Raises
    ------
    ValueError
        If ``index`` is out of range, or if the plate multipliers of this
        node and the parent are incompatible.
    """
    # Compute the message, check plates, apply mask and sum over some plates
    if index >= len(self.parents):
        raise ValueError("Parent index larger than the number of parents")

    # Compute the message and mask
    (m, mask) = self._get_message_and_mask_to_parent(index)
    mask = misc.squeeze(mask)

    # The parent we're sending the message to
    parent = self.parents[index]

    # Plates with respect to the parent
    plates_self = self._plates_to_parent(index)

    # Plate multiplier of the parent
    multiplier_parent = self._plates_multiplier_from_parent(index)

    def plate_multiplier_ratio():
        # Ratio of this node's plate multiplier to the parent's, with a
        # descriptive error when the two are incompatible.
        try:
            return self.broadcasting_multiplier(self.plates_multiplier,
                                                multiplier_parent)
        except Exception as err:
            raise ValueError("The plate multipliers are incompatible. "
                             "This node (%s) has %s and parent[%d] "
                             "(%s) has %s"
                             % (self.name,
                                self.plates_multiplier,
                                index,
                                parent.name,
                                multiplier_parent)) from err

    # Check if m is a logpdf function (for black-box variational inference)
    if callable(m):
        def m_function(*args):
            # The log pdf only contains plate axes, so it is reduced exactly
            # like an ndim=0 message below: mask, sum to the parent's plates
            # and apply the plate multiplier.
            lpdf = m(*args)
            r = plate_multiplier_ratio()
            return r * misc.sum_multiply_to_plates(lpdf, mask,
                                                   to_plates=parent.plates,
                                                   from_plates=plates_self,
                                                   ndim=0)
        return m_function

    # Compact the message to a proper shape
    for i in range(len(m)):
        # Empty messages are given as None. We can ignore those.
        if m[i] is None:
            continue

        r = plate_multiplier_ratio()

        ndim = len(parent.dims[i])
        # Source and target shapes
        if ndim > 0:
            dims = misc.broadcasted_shape(np.shape(m[i])[-ndim:],
                                          parent.dims[i])
            from_shape = plates_self + dims
        else:
            from_shape = plates_self
        to_shape = parent.get_shape(i)

        # Add variable axes to the mask
        mask_i = misc.add_trailing_axes(mask, ndim)

        # Apply mask and sum plate axes as necessary (and apply plate
        # multiplier)
        m[i] = r * misc.sum_multiply_to_plates(m[i], mask_i,
                                               to_plates=to_shape,
                                               from_plates=from_shape,
                                               ndim=0)
    return m
def _compute_message(*arrays, plates_from=(), plates_to=(), ndim=0):
    """
    A general function for computing messages by sum-multiply

    The function computes the product of the input arrays and then sums to
    the requested plates.

    Parameters
    ----------
    *arrays : array_like
        Arrays to multiply (with broadcasting).  The last ``ndim`` axes of
        their broadcast shape are variable dimensions; the leading axes are
        plates.
    plates_from : tuple
        Full plates of the source.  Used to correct for plates that are
        broadcast (length 1) in the arrays but summed away.
    plates_to : tuple
        Plates to sum to.  Must be broadcastable to ``plates_from``.
    ndim : int
        Number of trailing variable-dimension axes.

    Returns
    -------
    ndarray
        Shape ``plates_result + dims``.  ``plates_result`` covers the last
        ``min(len(plates_to), nplates)`` plate axes; each entry is the
        smaller of the target plate and the array plate.

    Raises
    ------
    ValueError
        If ``plates_to`` is not broadcastable to ``plates_from``.
    """
    # Check that the plates broadcast properly
    if not misc.is_shape_subset(plates_to, plates_from):
        raise ValueError("plates_to must be broadcastable to plates_from")

    # Compute the explicit shape of the product
    shapes = [np.shape(array) for array in arrays]
    arrays_shape = misc.broadcasted_shape(*shapes)

    # Compute plates and dims that are present
    if ndim == 0:
        arrays_plates = arrays_shape
        dims = ()
    else:
        arrays_plates = arrays_shape[:-ndim]
        dims = arrays_shape[-ndim:]

    # Compute the correction term. If some of the plates that should be
    # summed are actually broadcasted, one must multiply by the size of the
    # corresponding plate
    r = Node.broadcasting_multiplier(plates_from, arrays_plates, plates_to)

    # For simplicity, make the arrays equal ndim
    arrays = misc.make_equal_ndim(*arrays)

    # Keys for the input plates: (N-1, N-2, ..., 0).  Key k labels the k-th
    # plate axis counted from the END, which is what the lookup
    # plates_to[-key-1] below relies on.  (Ascending keys would pair the
    # leading array axis with the last target plate.)
    nplates = len(arrays_plates)
    in_plate_keys = list(range(nplates - 1, -1, -1))
    # Keep a plate axis in the output only if the target has it with a
    # non-unit length; all other plate axes are summed by einsum
    out_plate_keys = [
        key for key in in_plate_keys
        if key < len(plates_to) and plates_to[-key - 1] != 1
    ]
    # Keys for the dims (disjoint from the plate keys)
    dim_keys = list(range(nplates, nplates + ndim))
    # Total input and output keys
    in_keys = len(arrays) * [in_plate_keys + dim_keys]
    out_keys = out_plate_keys + dim_keys

    # Compute the sum-product with correction
    einsum_args = misc.zipper_merge(arrays, in_keys) + [out_keys]
    y = r * np.einsum(*einsum_args)

    # Reshape the result: summed axes come back as unit axes
    nplates_result = min(len(plates_to), len(arrays_plates))
    if nplates_result == 0:
        plates_result = []
    else:
        plates_result = [
            min(plates_to[ind], arrays_plates[ind])
            for ind in range(-nplates_result, 0)
        ]
    y = np.reshape(y, plates_result + list(dims))

    return y
def _message_to_parent(self, index):
    """
    Compute the message from this node to parent ``index``.

    The raw message and mask come from
    ``self._get_message_and_mask_to_parent(index)``.  Each non-``None``
    message entry goes through these steps:

    * It is multiplied by the (axis-reduced) mask.
    * It is summed over the plate axes the parent lacks.
    * It is scaled by ``self._plate_multiplier``, which accounts for plates
      that would otherwise be broadcast and summed.
    * Leading singular axes are squeezed so the result has as many axes as
      ``parent.get_shape(i)``.

    Raises
    ------
    ValueError
        If ``index`` is out of range, or if the plates of the message, the
        mask and the parent are not a broadcastable subset of this node's
        plates.
    """
    if index >= len(self.parents):
        raise ValueError("Parent index larger than the number of parents")

    (m, mask) = self._get_message_and_mask_to_parent(index)
    mask = misc.squeeze(mask)
    plates_mask = np.shape(mask)
    parent = self.parents[index]

    for i in range(len(m)):
        msg_i = m[i]
        if msg_i is None:
            # Empty message: nothing to compact
            continue

        # Split the message shape into plates and variable dimensions
        shape_m = np.shape(msg_i)
        ndim_parent = len(parent.dims[i])
        plates_m = shape_m[:-ndim_parent] if ndim_parent > 0 else shape_m

        # Plates that are singular in the message, the mask and the parent
        # would be broadcast and then summed; since the parent's axis is
        # singular that sum never happens, so multiply by the plate length.
        plates_self = self._plates_to_parent(index)
        try:
            r = self._plate_multiplier(plates_self,
                                       plates_m,
                                       plates_mask,
                                       parent.plates)
        except ValueError:
            raise ValueError("The plates of the message, the mask and "
                             "parent[%d] node (%s) are not a "
                             "broadcastable subset of the plates of "
                             "this node (%s). The message has shape "
                             "%s, meaning plates %s. The mask has "
                             "plates %s. This node has plates %s with "
                             "respect to the parent[%d], which has "
                             "plates %s."
                             % (index,
                                parent.name,
                                self.name,
                                shape_m,
                                plates_m,
                                plates_mask,
                                plates_self,
                                index,
                                parent.plates))

        # Mask with unit axes appended for the variable dimensions
        shape_mask = plates_mask + (1,) * ndim_parent
        mask_i = np.reshape(mask, shape_mask)

        # Collapse mask axes that exist neither in the message nor the parent
        shape_parent = parent.get_shape(i)
        shape_msg = misc.broadcasted_shape(shape_m, shape_parent)
        mask_i = np.sum(mask_i,
                        axis=misc.axes_to_collapse(shape_mask, shape_msg),
                        keepdims=True)

        # Masked product, summed over the plates the parent does not have,
        # then squeezed to the parent's number of axes
        reduced = misc.sum_multiply(mask_i, msg_i, r,
                                    axis=misc.axes_to_collapse(shape_msg,
                                                               shape_parent),
                                    keepdims=True)
        m[i] = misc.squeeze_to_dim(reduced, len(shape_parent))

    return m
def sum_plates(V, *plates):
    """
    Sum all elements of ``V``, scaled to account for broadcast plates.

    The factor comes from ``self.node_X.broadcasting_multiplier`` applied to
    the broadcast of ``plates`` and the shape of ``V``.

    NOTE(review): ``self`` is a free variable here, so this is presumably a
    helper nested inside a method -- confirm against the enclosing scope.
    """
    multiplier = self.node_X.broadcasting_multiplier(
        misc.broadcasted_shape(*plates),
        np.shape(V),
    )
    return multiplier * np.sum(V)