def _compute_moments(self, *u_parents):
    """
    Compute the output moments from the moments of the parents.

    The first output moment is an einsum over the parents' first moments
    (``u[0]``) and the second output moment is an einsum over the parents'
    second moments (``u[1]``).  The contraction pattern is given by
    ``self.in_keys`` (one key list per parent) and ``self.out_keys``.

    Parameters
    ----------
    *u_parents
        Moments of each parent, in the same order as ``self.in_keys``.
        Items 0 and 1 are always read; items 2 and 3 are read only when
        ``self.gaussian_gamma`` is true.

    Returns
    -------
    list
        ``[x0, x1]``, or ``[x0, x1, x2, x3]`` when ``self.gaussian_gamma``
        is true, where ``x2`` is the product of the parents' ``u[2]`` and
        ``x3`` is the sum of the parents' ``u[3]``.
    """

    def plate_keys(offset, count):
        # Keys offset+count-1, ..., offset: the last plate axis always gets
        # the key `offset`, so trailing plate axes of all operands line up.
        return list(range(offset + count - 1, offset - 1, -1))

    # How many leading plate axes each parent has in its first moment
    # (one axis per key) and in its second moment (two axes per key)
    mean_plate_counts = [
        np.ndim(u[0]) - len(keys)
        for (keys, u) in zip(self.in_keys, u_parents)
    ]
    cov_plate_counts = [
        np.ndim(u[1]) - 2 * len(keys)
        for (keys, u) in zip(self.in_keys, u_parents)
    ]

    # The output carries as many plate axes as the widest parent
    n_mean_plates = max(mean_plate_counts)
    n_cov_plates = max(cov_plate_counts)

    # Number of unique keys in use (keys are 0, 1, ..., N_keys-1)
    n_keys = self.N_keys

    #
    # First moment: plate keys start right after the variable keys
    #
    mean_out_keys = plate_keys(n_keys, n_mean_plates) + self.out_keys
    mean_in_keys = [
        plate_keys(n_keys, count) + keys
        for (count, keys) in zip(mean_plate_counts, self.in_keys)
    ]
    first_moments = [u[0] for u in u_parents]
    mean_args = misc.zipper_merge(first_moments, mean_in_keys)
    x0 = np.einsum(*(mean_args + [mean_out_keys]))

    #
    # Second moment: every variable key k appears twice, as n_keys+k and
    # as k, so plate keys start after 2*n_keys
    #
    cov_out_keys = (
        plate_keys(2 * n_keys, n_cov_plates)
        + [n_keys + key for key in self.out_keys]
        + self.out_keys
    )
    cov_in_keys = [
        plate_keys(2 * n_keys, count)
        + [n_keys + key for key in node_keys]
        + node_keys
        for (count, node_keys) in zip(cov_plate_counts, self.in_keys)
    ]
    second_moments = [u[1] for u in u_parents]
    cov_args = misc.zipper_merge(second_moments, cov_in_keys)
    x1 = np.einsum(*(cov_args + [cov_out_keys]))

    if not self.gaussian_gamma:
        return [x0, x1]

    # Gaussian-gamma specific moments: product of the third moments and
    # sum of the fourth moments over all parents
    x2 = 1
    x3 = 0
    for u in u_parents:
        x2 = x2 * u[2]
        x3 = x3 + u[3]

    return [x0, x1, x2, x3]
def MultiMixture(thetas, *mixture_args, **kwargs):
    """Creates a mixture over several axes using as many categorical variables.

    The mixings are assumed to be separate, that is, inner mixings don't affect
    the parameters of outer mixings.

    The resulting call is
    ``Mixture(theta_0, Mixture, theta_1, ..., Mixture, theta_{N-1},
    *mixture_args, **kwargs)``, where ``theta_i`` is the ``i``-th element of
    `thetas` indexed with ``i`` extra trailing ``None`` axes.
    """
    thetas = list(thetas)

    # Each mixed axis is assumed separate from the others, so the i-th theta
    # is given i new trailing axes.
    expanded = []
    for (depth, theta) in enumerate(thetas):
        index = (Ellipsis,) + depth * (None,)
        expanded.append(theta[index])

    # The first theta goes in as such; every later theta is preceded by
    # the Mixture class itself.
    n_inner = len(expanded) - 1
    args = expanded[:1]
    args.extend(misc.zipper_merge(n_inner * [Mixture], expanded[1:]))
    args.extend(mixture_args)

    return Mixture(*args, **kwargs)
def MultiMixture(thetas, *mixture_args, **kwargs):
    """Creates a mixture over several axes using as many categorical variables.

    The mixings are assumed to be separate, that is, inner mixings don't affect
    the parameters of outer mixings.

    Every element of `thetas` after the first is passed to ``Mixture``
    preceded by the ``Mixture`` class itself; `mixture_args` and `kwargs` are
    forwarded unchanged after them.
    """
    thetas = list(thetas)
    n_thetas = len(thetas)

    # Because each mixed axis is assumed to be separate from the others,
    # append i trailing axes (via None indexing) to the i-th theta.
    padded = [
        theta[(Ellipsis,) + (None,) * position]
        for (position, theta) in enumerate(thetas)
    ]

    head = padded[:1]
    tail = padded[1:]
    interleaved = misc.zipper_merge([Mixture] * (n_thetas - 1), tail)

    return Mixture(*[*head, *interleaved, *mixture_args], **kwargs)
def _compute_function(self, *x_parents):
    """
    Evaluate the einsum of the parent values.

    When ``self.gaussian_gamma`` is false, each element of `x_parents` is
    used directly as an einsum operand and the einsum result is returned.
    Otherwise each element of `x_parents` is unpacked as a pair
    ``(x, alpha)``; the ``x`` values are the einsum operands and the return
    value is the tuple ``(y, misc.multiply(*alphas))``.
    """
    # TODO: Add unit tests for this function
    if self.gaussian_gamma:
        xs, alphas = zip(*x_parents)
    else:
        xs = x_parents

    # Prefix every key list with Ellipsis so that the plate axes are
    # passed through the einsum
    keys_in = [[Ellipsis] + keys for keys in self.in_keys]
    keys_out = [Ellipsis] + self.out_keys

    einsum_args = misc.zipper_merge(xs, keys_in)
    einsum_args = einsum_args + [keys_out]
    y = np.einsum(*einsum_args)

    if self.gaussian_gamma:
        return (y, misc.multiply(*alphas))
    return y
def _compute_message(*arrays, plates_from=(), plates_to=(), ndim=0):
    """
    A general function for computing messages by sum-multiply

    The function computes the product of the input arrays and then sums to
    the requested plates.

    Parameters
    ----------
    *arrays
        Arrays to be multiplied together (with broadcasting).
    plates_from : tuple
        Plates against which the broadcasting correction is computed.
    plates_to : tuple
        Plates of the result.  Must be a shape subset of `plates_from`.
        Plate axes that are missing from `plates_to`, or have size 1 in it,
        are summed out.
    ndim : int
        Number of trailing axes of the product that are not plates.  These
        axes are never summed.

    Returns
    -------
    The sum-product, reshaped to the trailing plates shared by `plates_to`
    and the product (element-wise minimum of the two) followed by the
    `ndim` trailing dims.

    Raises
    ------
    ValueError
        If `plates_to` is not a shape subset of `plates_from`.
    """

    # Check that the plates broadcast properly
    if not misc.is_shape_subset(plates_to, plates_from):
        raise ValueError("plates_to must be broadcastable to plates_from")

    # Compute the explicit shape of the product
    shapes = [np.shape(array) for array in arrays]
    arrays_shape = misc.broadcasted_shape(*shapes)

    # Compute plates and dims that are present
    if ndim == 0:
        arrays_plates = arrays_shape
        dims = ()
    else:
        arrays_plates = arrays_shape[:-ndim]
        dims = arrays_shape[-ndim:]

    # Compute the correction term.  If some of the plates that should be
    # summed are actually broadcasted, one must multiply by the size of the
    # corresponding plate
    r = Node.broadcasting_multiplier(plates_from, arrays_plates, plates_to)

    # For simplicity, make the arrays equal ndim
    arrays = misc.make_equal_ndim(*arrays)

    # Keys for the input plates: (N-1, N-2, ..., 0).  The key of a plate
    # axis is its distance from the LAST plate axis, because the filter
    # below looks it up as plates_to[-key-1].  Numbering the axes in
    # ascending order instead would sum out the wrong axes whenever the
    # plates are not symmetric.
    nplates = len(arrays_plates)
    in_plate_keys = list(range(nplates - 1, -1, -1))

    # Keys for the output plates: keep a plate axis only if it exists in
    # plates_to and is not a unit axis there
    out_plate_keys = [
        key
        for key in in_plate_keys
        if key < len(plates_to) and plates_to[-key - 1] != 1
    ]

    # Keys for the dims (numbered after all the plate keys)
    dim_keys = list(range(nplates, nplates + ndim))

    # Total input and output keys
    in_keys = len(arrays) * [in_plate_keys + dim_keys]
    out_keys = out_plate_keys + dim_keys

    # Compute the sum-product with correction
    einsum_args = misc.zipper_merge(arrays, in_keys) + [out_keys]
    y = r * np.einsum(*einsum_args)

    # Reshape the result so that the summed plate axes which are present
    # in plates_to come back as unit axes
    nplates_result = min(len(plates_to), len(arrays_plates))
    if nplates_result == 0:
        plates_result = []
    else:
        plates_result = [
            min(plates_to[ind], arrays_plates[ind])
            for ind in range(-nplates_result, 0)
        ]
    y = np.reshape(y, plates_result + list(dims))

    return y
def _compute_message(*arrays, plates_from=(), plates_to=(), ndim=0):
    """
    A general function for computing messages by sum-multiply

    The function computes the product of the input arrays and then sums to
    the requested plates.

    Parameters
    ----------
    *arrays
        Arrays whose (broadcasted) product is summed.
    plates_from : tuple
        Plates used as the reference when computing the broadcasting
        correction.
    plates_to : tuple
        Requested plates of the result; must be a shape subset of
        `plates_from`.  Plate axes absent from `plates_to`, or of size 1
        in it, are summed out.
    ndim : int
        Number of trailing non-plate axes of the product; they are kept
        as they are.

    Returns
    -------
    The corrected sum-product with shape ``plates_result + dims``, where
    ``plates_result`` is the element-wise minimum of the trailing plates
    common to `plates_to` and the product.

    Raises
    ------
    ValueError
        If `plates_to` is not a shape subset of `plates_from`.
    """

    # Check that the plates broadcast properly
    if not misc.is_shape_subset(plates_to, plates_from):
        raise ValueError("plates_to must be broadcastable to plates_from")

    # Compute the explicit shape of the product
    shapes = [np.shape(array) for array in arrays]
    arrays_shape = misc.broadcasted_shape(*shapes)

    # Compute plates and dims that are present
    if ndim == 0:
        arrays_plates = arrays_shape
        dims = ()
    else:
        arrays_plates = arrays_shape[:-ndim]
        dims = arrays_shape[-ndim:]

    # Compute the correction term.  If some of the plates that should be
    # summed are actually broadcasted, one must multiply by the size of the
    # corresponding plate
    r = Node.broadcasting_multiplier(plates_from, arrays_plates, plates_to)

    # For simplicity, make the arrays equal ndim
    arrays = misc.make_equal_ndim(*arrays)

    # Keys for the input plates: (N-1, N-2, ..., 0).  Keys must count the
    # plate axes from the end: the output filter indexes plates_to with
    # -key-1, so key 0 has to be the last plate axis.  Ascending keys
    # (0, ..., N-1) would pair each axis with the wrong entry of plates_to.
    nplates = len(arrays_plates)
    in_plate_keys = list(range(nplates - 1, -1, -1))

    # Keys for the output plates (axes that are neither missing from
    # plates_to nor unit axes in it)
    out_plate_keys = [
        key
        for key in in_plate_keys
        if key < len(plates_to) and plates_to[-key - 1] != 1
    ]

    # Keys for the dims; these come after the plate keys so they never clash
    dim_keys = list(range(nplates, nplates + ndim))

    # Total input and output keys
    in_keys = len(arrays) * [in_plate_keys + dim_keys]
    out_keys = out_plate_keys + dim_keys

    # Compute the sum-product with correction
    einsum_args = misc.zipper_merge(arrays, in_keys) + [out_keys]
    y = r * np.einsum(*einsum_args)

    # Reshape the result: summed plate axes that exist in plates_to are
    # restored as unit axes
    nplates_result = min(len(plates_to), len(arrays_plates))
    if nplates_result == 0:
        plates_result = []
    else:
        plates_result = [
            min(plates_to[ind], arrays_plates[ind])
            for ind in range(-nplates_result, 0)
        ]
    y = np.reshape(y, plates_result + list(dims))

    return y
def _compute_moments(self, *u_parents):
    """
    Compute the output moments from the moments of the parents.

    Parents flagged in ``self.is_constant`` contribute only their first
    moment ``u[0]``: it is used twice in place of a second moment, and such
    parents are skipped (factor 1, term 0) in the Gaussian-gamma moments.
    Other parents contribute ``u[0]`` to the first output moment and
    ``u[1]`` to the second.

    Parameters
    ----------
    *u_parents
        Moments of each parent, in the same order as ``self.in_keys`` and
        ``self.is_constant``.

    Returns
    -------
    list
        ``[x0, x1]``, or ``[x0, x1, x2, x3]`` when ``self.gaussian_gamma``
        is true.
    """

    def plate_keys(offset, count):
        # Keys offset+count-1, ..., offset: the last plate axis always gets
        # the key `offset`, so trailing plate axes of all operands line up.
        return list(range(offset + count - 1, offset - 1, -1))

    # Number of plate axes in the first moment of each parent
    mean_plate_counts = [
        np.ndim(u[0]) - len(keys)
        for (keys, u) in zip(self.in_keys, u_parents)
    ]

    # Number of plate axes in whatever is used as the second moment
    cov_plate_counts = []
    for (keys, u, is_const) in zip(self.in_keys, u_parents, self.is_constant):
        if is_const:
            # Delta moments: use the first moment "vector"
            cov_plate_counts.append(np.ndim(u[0]) - len(keys))
        else:
            # Gaussian moments: use the second moment "matrix"
            cov_plate_counts.append(np.ndim(u[1]) - 2 * len(keys))

    # The output carries as many plate axes as the widest parent
    n_mean_plates = max(mean_plate_counts)
    n_cov_plates = max(cov_plate_counts)

    # Number of unique keys in use (keys are 0, 1, ..., N_keys-1)
    n_keys = self.N_keys

    #
    # First moment
    #
    mean_out_keys = plate_keys(n_keys, n_mean_plates) + self.out_keys
    mean_in_keys = [
        plate_keys(n_keys, count) + keys
        for (count, keys) in zip(mean_plate_counts, self.in_keys)
    ]
    first_moments = [u[0] for u in u_parents]
    mean_args = misc.zipper_merge(first_moments, mean_in_keys)
    x0 = np.einsum(*(mean_args + [mean_out_keys]))

    #
    # Second moment: every variable key k appears twice, as n_keys+k and
    # as k
    #
    cov_out_keys = (
        plate_keys(2 * n_keys, n_cov_plates)
        + [n_keys + key for key in self.out_keys]
        + self.out_keys
    )

    cov_in_keys = []
    for (count, node_keys, is_const) in zip(
            cov_plate_counts, self.in_keys, self.is_constant):
        plates = plate_keys(2 * n_keys, count)
        shifted_keys = [n_keys + key for key in node_keys]
        if is_const:
            # Delta moments: two operands, one per copy of the first moment
            cov_in_keys.append(plates + shifted_keys)
            cov_in_keys.append(plates + node_keys)
        else:
            # Gaussian moments: a single operand with both sets of keys
            cov_in_keys.append(plates + shifted_keys + node_keys)

    cov_operands = []
    for (u, is_const) in zip(u_parents, self.is_constant):
        if is_const:
            # Delta moments: use the first moment twice
            cov_operands.extend([u[0], u[0]])
        else:
            # Gaussian moments: use the second moment
            cov_operands.append(u[1])

    cov_args = misc.zipper_merge(cov_operands, cov_in_keys)
    x1 = np.einsum(*(cov_args + [cov_out_keys]))

    if not self.gaussian_gamma:
        return [x0, x1]

    # Gaussian-gamma specific moments; constant parents act as the
    # neutral elements of the product and of the sum
    x2 = 1
    x3 = 0
    for (idx, u) in enumerate(u_parents):
        if self.is_constant[idx]:
            factor, term = 1, 0
        else:
            factor, term = u[2], u[3]
        x2 = x2 * factor
        x3 = x3 + term

    return [x0, x1, x2, x3]