def _compute_message_to_parent(self, index, m_child, u_Z, u_X):
    """
    Compute the message to the parent selected by `index`.

    For ``index == 0`` a one-element list is returned: each child message
    is multiplied by the matching moment in `u_X`, summed over the
    variable axes, and the results are accumulated over all moments with
    the gated axis as the last axis.  For ``index == 1`` one message per
    child message is returned: the first moment in `u_Z` times the child
    message, with the gate axis moved to the gated plate position.

    Raises ``ValueError`` for any other `index`.
    """
    if index == 0:
        total = 0
        for (i, msg) in enumerate(m_child):
            nd = len(self.dims[i])
            # Position just in front of the variable axes
            target = -nd - 1

            # Give the child message a new axis for the gate and place it
            # in front of the variable axes so it lines up with X
            child = misc.moveaxis(msg[..., None], -1, target)

            gate_ax = self.gated_plate - nd
            x_i = u_X[i]
            if np.ndim(x_i) >= abs(gate_ax):
                x_i = misc.moveaxis(x_i, gate_ax, target)
            else:
                # X has no gated axis, so insert a unit axis instead
                x_i = np.expand_dims(x_i, target)

            var_axes = tuple(range(-nd, 0))
            total = total + misc.sum_product(child, x_i, axes_to_sum=var_axes)

        # Multiply by ones so that the last axis is not left broadcasted
        return [total * np.ones(self.K)]

    if index != 1:
        raise ValueError("Invalid parent index")

    messages = []
    for (i, msg) in enumerate(m_child):
        nd = len(self.dims[i])

        # Append variable axes to the Z moment and move its gate axis last;
        # the gate is handled as the last axis until the final move below
        z = misc.moveaxis(misc.add_trailing_axes(u_Z[0], nd), -nd - 1, -1)

        # Where the gated plate belongs in the outgoing message
        gate_ax = self.gated_plate - nd

        # Child message with a trailing gate axis, times Z
        product = z * misc.add_trailing_axes(msg, 1)

        # Pad with leading axes when the target position does not exist yet
        missing = abs(gate_ax) - np.ndim(product)
        if missing > 0:
            product = misc.add_leading_axes(product, missing)

        messages.append(misc.moveaxis(product, -1, gate_ax))
    return messages
def _compute_message_to_parent(self, index, m_child, u_Z, u_X):
    """
    Compute the message to the parent selected by `index`.

    For ``index == 0`` a one-element list is returned: each child message
    is multiplied by the matching moment in `u_X`, summed over the
    variable axes, and the results are accumulated over all moments with
    the gated axis as the last axis.  For ``index == 1`` one message per
    child message is returned: the first moment in `u_Z` times the child
    message, with the gate axis moved to the gated plate position.

    Raises ``ValueError`` for any other `index`.
    """
    if index == 0:
        accumulated = 0
        for (k, child_msg) in enumerate(m_child):
            n_var = len(self.dims[k])
            # Position just in front of the variable axes
            before_vars = -n_var - 1

            # Give the child message a new axis for the gate and place it
            # in front of the variable axes so it lines up with X
            c = misc.moveaxis(child_msg[..., None], -1, before_vars)

            gate_axis = self.gated_plate - n_var
            x = u_X[k]
            if np.ndim(x) >= abs(gate_axis):
                x = misc.moveaxis(x, gate_axis, before_vars)
            else:
                # X has no gated axis, so insert a unit axis instead
                x = np.expand_dims(x, before_vars)

            accumulated = accumulated + misc.sum_product(
                c, x, axes_to_sum=tuple(range(-n_var, 0)))

        # Multiply by ones so that the last axis is not left broadcasted
        return [accumulated * np.ones(self.K)]

    if index != 1:
        raise ValueError("Invalid parent index")

    out = []
    for (k, child_msg) in enumerate(m_child):
        n_var = len(self.dims[k])

        # Append variable axes to the Z moment and move its gate axis last;
        # the gate is handled as the last axis until the final move below
        z = misc.add_trailing_axes(u_Z[0], n_var)
        z = misc.moveaxis(z, -n_var - 1, -1)

        # Where the gated plate belongs in the outgoing message
        gate_axis = self.gated_plate - n_var

        # Child message with a trailing gate axis, times Z
        msg = z * misc.add_trailing_axes(child_msg, 1)

        # Pad with leading axes when the target position does not exist yet
        shortfall = abs(gate_axis) - np.ndim(msg)
        if shortfall > 0:
            msg = misc.add_leading_axes(msg, shortfall)

        out.append(misc.moveaxis(msg, -1, gate_axis))
    return out
def _compute_moments(self, u_Z, u_X):
    """
    Compute the moments as a sum-product of Z and X over the gated plate.

    For every moment in `u_X`, the first moment in `u_Z` gets trailing
    variable axes and has its gate axis moved last.  The gated plate of X
    is moved last as well (a unit trailing axis is added when X does not
    have that axis), and the two are multiplied and summed over the last
    axis.  Returns the list of resulting moments.
    """
    moments = []
    for (i, x_i) in enumerate(u_X):
        nd = len(self.dims[i])

        # Z with variable axes appended and the gate axis moved last
        z = misc.moveaxis(misc.add_trailing_axes(u_Z[0], nd), -nd - 1, -1)

        gate_ax = self.gated_plate - nd
        if np.ndim(x_i) >= abs(gate_ax):
            x = misc.moveaxis(x_i, gate_ax, -1)
        else:
            # X lacks the gated axis: broadcast it over the gate
            x = misc.add_trailing_axes(x_i, 1)

        moments.append(misc.sum_product(z, x, axes_to_sum=-1))
    return moments
def _compute_moments(self, u_Z, u_X):
    """
    Compute the moments as a sum-product of Z and X over the gated plate.

    For every moment in `u_X`, the first moment in `u_Z` gets trailing
    variable axes and has its gate axis moved last.  The gated plate of X
    is moved last as well (a unit trailing axis is added when X does not
    have that axis), and the two are multiplied and summed over the last
    axis.  Returns the list of resulting moments.
    """
    result = []
    for (k, x_moment) in enumerate(u_X):
        n_var = len(self.dims[k])

        # Z with variable axes appended and the gate axis moved last
        z = misc.add_trailing_axes(u_Z[0], n_var)
        z = misc.moveaxis(z, -n_var - 1, -1)

        gate_axis = self.gated_plate - n_var
        has_gate_axis = np.ndim(x_moment) >= abs(gate_axis)
        if has_gate_axis:
            x = misc.moveaxis(x_moment, gate_axis, -1)
        else:
            # X lacks the gated axis: broadcast it over the gate
            x = misc.add_trailing_axes(x_moment, 1)

        result.append(misc.sum_product(z, x, axes_to_sum=-1))
    return result
def _compute_moments(self, u_Lambda, u_alpha):
    """
    Combine the moments of Lambda and alpha into the moments of the node.

    Returns ``[u0, u1]`` where ``u0`` is the first Lambda moment times the
    first alpha moment (alpha gets ``2 * self._moments.ndim`` trailing axes
    so that it broadcasts) and ``u1`` is the second Lambda moment plus
    ``np.prod(self._moments.shape)`` times the second alpha moment.
    """
    lam = u_Lambda[0]
    lam_logdet = u_Lambda[1]
    scale = misc.add_trailing_axes(u_alpha[0], 2 * self._moments.ndim)
    log_scale = u_alpha[1]
    return [
        lam * scale,
        lam_logdet + np.prod(self._moments.shape) * log_scale,
    ]
def _compute_moments(self, u_Lambda, u_alpha):
    """
    Combine the moments of Lambda and alpha into the moments of the node.

    Returns ``[u0, u1]`` where ``u0`` is the first Lambda moment times the
    first alpha moment (alpha gets ``2 * self._moments.ndim`` trailing axes
    so that it broadcasts) and ``u1`` is the second Lambda moment plus
    ``np.prod(self._moments.shape)`` times the second alpha moment.
    """
    (mean_lambda, logdet_lambda) = (u_Lambda[0], u_Lambda[1])
    # Trailing axes let alpha broadcast against Lambda
    mean_alpha = misc.add_trailing_axes(u_alpha[0], 2 * self._moments.ndim)
    log_alpha = u_alpha[1]

    first = mean_lambda * mean_alpha
    second = logdet_lambda + np.prod(self._moments.shape) * log_alpha
    return [first, second]
def _set_moments(self, u, mask=True, broadcast=True):
    """
    Store the moments `u` into ``self.u`` for the elements selected by `mask`.

    The shape of `u` is first validated with ``self._check_shape``.  For
    each moment, ``self.u[ind]`` is enlarged to the broadcasted shape of
    the stored array, the new array and the mask, and the new values are
    copied in only where the mask is true.  Finally, leading axes are added
    to the stored array if it has fewer axes than ``self.get_shape(ind)``.

    Note that this function is also used to set observations, so the
    caller is responsible for passing a mask that selects only the
    elements that should be overwritten.

    Raises ``RuntimeError`` if a stored moment ends up with more axes than
    ``self.get_shape(ind)``.
    """
    self._check_shape(u, broadcast=broadcast)

    # Store the computed moments u but do not change moments for
    # observations, i.e., utilize the mask.
    for ind in range(len(u)):
        # Add axes to the mask for the variable dimensions (mask
        # contains only axes for the plates).
        u_mask = misc.add_trailing_axes(mask, self.ndims[ind])

        # Enlarge self.u[ind] as necessary so that it can store the
        # broadcasted result.
        sh = misc.broadcasted_shape_from_arrays(self.u[ind], u[ind], u_mask)
        self.u[ind] = misc.repeat_to_shape(self.u[ind], sh)

        # TODO/FIXME/BUG: The mask of observations is not used, observations
        # may be overwritten!!! ???
        # Hah, this function is used to set the observations! The caller
        # should be careful what mask he uses! If you want to set only
        # latent variables, then use such a mask.

        # Use mask to update only unobserved plates and keep the
        # observed as before
        np.copyto(self.u[ind], u[ind], where=u_mask)

        # Make sure u has the correct number of dimensions:
        shape = self.get_shape(ind)
        ndim = len(shape)
        ndim_u = np.ndim(self.u[ind])
        if ndim > ndim_u:
            # Pad the stored (mask-merged) array, not the raw input: ndim_u
            # was measured on self.u[ind], and re-wrapping u[ind] would
            # throw away the values kept by the mask above.
            self.u[ind] = misc.add_leading_axes(self.u[ind], ndim - ndim_u)
        elif ndim < ndim_u:
            # This should not ever happen because we already checked the
            # shape at the beginning of the function.
            raise RuntimeError(
                "This error should not happen. Fix shape checking."
                "The size of the variable %s's %s-th moment "
                "array is %s which is larger than it should "
                "be, that is, %s, based on the plates %s and "
                "dimension %s. Check that you have provided "
                "plates properly."
                % (self.name,
                   ind,
                   np.shape(self.u[ind]),
                   shape,
                   self.plates,
                   self.dims[ind]))
def compute_fixed_moments(self, Lambda, gradient=None):
    """
    Compute moments for fixed x.

    Returns ``[Lambda, ldet]`` where ``ldet`` is obtained with
    ``linalg.chol_logdet`` from the Cholesky factor of `Lambda`.  When
    `gradient` is given, returns ``(u, du)`` instead, where ``du`` is
    ``gradient[0]`` plus ``gradient[1]`` (with ``2 * self.ndim`` trailing
    axes) times ``linalg.chol_inv`` of the Cholesky factor.
    """
    chol = linalg.chol(Lambda, ndim=self.ndim)
    moments = [Lambda, linalg.chol_logdet(chol, ndim=self.ndim)]

    if gradient is not None:
        # Contribution of the second moment, broadcast over the matrix axes
        through_logdet = (
            misc.add_trailing_axes(gradient[1], 2 * self.ndim)
            * linalg.chol_inv(chol, ndim=self.ndim)
        )
        return (moments, gradient[0] + through_logdet)

    return moments
def _compute_message_to_parent(self, index, m, u_Lambda, u_alpha):
    """
    Compute the two-element message to the parent selected by `index`.

    For ``index == 0`` the first message element is ``m[0]`` times the
    first alpha moment (with ``2 * self._moments.ndim`` trailing axes) and
    the second is ``m[1]`` unchanged.  For ``index == 1`` the first element
    is ``linalg.inner`` of ``m[0]`` and the first Lambda moment, and the
    second is ``m[1]`` times ``np.prod(self._moments.shape)``.

    Raises ``IndexError`` for any other `index`.
    """
    if index == 0:
        scale = misc.add_trailing_axes(u_alpha[0], 2 * self._moments.ndim)
        _log_scale = u_alpha[1]  # fetched but unused here
        return [m[0] * scale, m[1]]

    if index == 1:
        lam = u_Lambda[0]
        _lam_logdet = u_Lambda[1]  # fetched but unused here
        to_alpha = linalg.inner(m[0], lam, ndim=2 * self._moments.ndim)
        return [to_alpha, m[1] * np.prod(self._moments.shape)]

    raise IndexError()
def _compute_message_to_parent(self, index, m, u_Lambda, u_alpha):
    """
    Compute the two-element message to the parent selected by `index`.

    For ``index == 0`` the first message element is ``m[0]`` times the
    first alpha moment (with ``2 * self._moments.ndim`` trailing axes) and
    the second is ``m[1]`` unchanged.  For ``index == 1`` the first element
    is ``linalg.inner`` of ``m[0]`` and the first Lambda moment, and the
    second is ``m[1]`` times ``np.prod(self._moments.shape)``.

    Raises ``IndexError`` for any other `index`.
    """
    if index == 0:
        mean_alpha = misc.add_trailing_axes(u_alpha[0],
                                            2 * self._moments.ndim)
        _log_alpha = u_alpha[1]  # fetched but unused here
        first = m[0] * mean_alpha
        second = m[1]
        return [first, second]

    if index == 1:
        mean_lambda = u_Lambda[0]
        _logdet_lambda = u_Lambda[1]  # fetched but unused here
        first = linalg.inner(m[0], mean_lambda, ndim=2 * self._moments.ndim)
        second = m[1] * np.prod(self._moments.shape)
        return [first, second]

    raise IndexError()
def compute_fixed_moments(self, Lambda, gradient=None):
    """
    Compute moments for fixed x.

    Returns ``[Lambda, ldet]`` where ``ldet`` is obtained with
    ``linalg.chol_logdet`` from the Cholesky factor of `Lambda`.  When
    `gradient` is given, returns ``(u, du)`` instead, where ``du`` is
    ``gradient[0]`` plus ``gradient[1]`` (with ``2 * self.ndim`` trailing
    axes) times ``linalg.chol_inv`` of the Cholesky factor.
    """
    factor = linalg.chol(Lambda, ndim=self.ndim)
    log_det = linalg.chol_logdet(factor, ndim=self.ndim)
    fixed = [Lambda, log_det]

    if gradient is None:
        return fixed

    # The second gradient term is broadcast over the matrix axes before it
    # is multiplied with the inverse computed from the Cholesky factor
    weight = misc.add_trailing_axes(gradient[1], 2 * self.ndim)
    grad_total = gradient[0] + weight * linalg.chol_inv(factor, ndim=self.ndim)
    return (fixed, grad_total)
def lower_bound_contribution(self, gradient=False):
    """
    Compute E[ log p(X|parents) - log q(X) ] over q(X)q(parents).

    The natural parameters and the cgf given by the parents are combined
    with ``self.f`` for observed elements and ``-self.g`` for latent
    elements.  For each moment, the difference between the parents' phi
    and the node's own phi (zeroed where observed) is multiplied by the
    moment and summed over the variable axes.  The result is summed over
    the elements where ``self.mask`` holds and scaled with
    ``self._plate_multiplier``.

    `gradient` is accepted but not used.
    """
    # Messages from parents
    u_parents = self._message_from_parents()
    phi = self._distribution.compute_phi_from_parents(*u_parents)
    # G from parents
    L = self._distribution.compute_cgf_from_parents(*u_parents)
    # G for unobserved variables (ignored variables are handled
    # properly automatically)
    latent_mask = np.logical_not(self.observed)
    # F for observed, G for latent
    L = L + np.where(self.observed, self.f, -self.g)
    for (phi_p, phi_q, u_q, dims) in zip(phi, self.phi, self.u, self.dims):
        # Form a mask which puts observed variables to zero and
        # broadcasts properly
        latent_mask_i = misc.add_trailing_axes(
            misc.add_leading_axes(
                latent_mask,
                len(self.plates) - np.ndim(latent_mask)),
            len(dims))
        axis_sum = tuple(range(-len(dims), 0))

        # Compute the term
        phi_q = np.where(latent_mask_i, phi_q, 0)
        phi_diff = phi_p - phi_q
        # Handle 0 * -inf: an infinite natural parameter paired with a zero
        # moment would otherwise turn the whole bound into nan
        phi_diff = np.where(u_q != 0, phi_diff, 0)
        # TODO/FIXME: Use einsum here?
        Z = np.sum(phi_diff * u_q, axis=axis_sum)

        L = L + Z

    return (np.sum(np.where(self.mask, L, 0))
            * self._plate_multiplier(self.plates,
                                     np.shape(L),
                                     np.shape(self.mask)))
def lower_bound_contribution(self, gradient=False, ignore_masked=True):
    r"""Compute E[ log p(X|parents) - log q(X) ]

    With deterministic annealing, the term E[ -log q(X) ] is divided by the
    annealing coefficient: phi and cgf of q are multiplied by the
    temperature (the inverse of the annealing coefficient).

    If `ignore_masked` is true, only elements where ``self.mask`` holds
    contribute to the sum; otherwise every element is summed.  `gradient`
    is accepted but not used.
    """
    # Temperature is the inverse of the annealing coefficient
    temperature = 1 / self.annealing

    # Natural parameters and cgf (G) as given by the parents
    parent_moments = self._message_from_parents()
    phi_parents = self._distribution.compute_phi_from_parents(*parent_moments)
    bound = self._distribution.compute_cgf_from_parents(*parent_moments)

    # Latent elements are the ones that are not observed (ignored
    # variables are handled properly automatically)
    not_observed = np.logical_not(self.observed)

    # F for observed elements, annealed -G for latent ones.  self.g is not
    # read at all when everything is observed.  F of q is deliberately not
    # annealed: the optimal q has f weighted by 1/T and here it would be
    # weighted by T, so the total weight is 1 and it cancels with f of p.
    if np.all(self.observed):
        latent_term = np.nan
    elif temperature == 1:
        latent_term = -self.g
    else:
        latent_term = -temperature * self.g
    bound = bound + np.where(self.observed, self.f, latent_term)

    for (phi_p, phi_q, u_q, dims) in zip(phi_parents,
                                         self.phi,
                                         self.u,
                                         self.dims):
        n_var = len(dims)

        # Mask that zeroes observed elements and broadcasts over both the
        # plates and the variable axes
        plate_mask = misc.add_leading_axes(
            not_observed,
            len(self.plates) - np.ndim(not_observed))
        full_mask = misc.add_trailing_axes(plate_mask, n_var)

        own_phi = np.where(full_mask, phi_q, 0)

        # Annealed difference of the natural parameters
        diff = phi_p - temperature * own_phi

        # Avoid 0 * -inf = nan where the moment is exactly zero
        diff = np.where(u_q != 0, diff, 0)

        # TODO/FIXME: Use einsum here?
        bound = bound + np.sum(diff * u_q, axis=tuple(range(-n_var, 0)))

    if not ignore_masked:
        return (np.sum(bound)
                * self.broadcasting_multiplier(self.plates, np.shape(bound))
                * np.prod(self.plates_multiplier))

    return (np.sum(np.where(self.mask, bound, 0))
            * self.broadcasting_multiplier(self.plates,
                                           np.shape(bound),
                                           np.shape(self.mask))
            * np.prod(self.plates_multiplier))
def lower_bound_contribution(self, gradient=False, ignore_masked=True):
    r"""Compute E[ log p(X|parents) - log q(X) ]

    If deterministic annealing is used, the term E[ -log q(X) ] is
    divided by the annealing coefficient.  That is, phi and cgf of q
    are multiplied by the temperature (inverse annealing
    coefficient).

    If `ignore_masked` is true, only elements where ``self.mask`` holds
    contribute to the sum; otherwise every element is summed.  `gradient`
    is accepted but not used.
    """
    # Annealing temperature
    T = 1 / self.annealing

    # Messages from parents
    u_parents = self._message_from_parents()
    phi = self._distribution.compute_phi_from_parents(*u_parents)
    # G from parents
    L = self._distribution.compute_cgf_from_parents(*u_parents)

    # G for unobserved variables (ignored variables are handled properly
    # automatically)
    latent_mask = np.logical_not(self.observed)

    # G and F
    if np.all(self.observed):
        z = np.nan
    elif T == 1:
        z = -self.g
    else:
        z = -T * self.g
        ## TRIED THIS BUT IT WAS WRONG:
        ## z = -T * self.g + (1-T) * self.f
        ## if np.any(np.isnan(self.f)):
        ##     warnings.warn("F(x) not implemented for node %s. This "
        ##                   "is required for annealed lower bound "
        ##                   "computation." % self.__class__.__name__)
        ##
        ## It was wrong because the optimal q distribution has f which is
        ## weighted by 1/T and here the f of q is weighted by T so the
        ## total weight is 1, thus it cancels out with f of p.

    L = L + np.where(self.observed, self.f, z)

    for (phi_p, phi_q, u_q, dims) in zip(phi, self.phi, self.u, self.dims):
        # Form a mask which puts observed variables to zero and
        # broadcasts properly
        latent_mask_i = misc.add_trailing_axes(
            misc.add_leading_axes(
                latent_mask,
                len(self.plates) - np.ndim(latent_mask)),
            len(dims))
        axis_sum = tuple(range(-len(dims), 0))

        # Compute the term
        phi_q = np.where(latent_mask_i, phi_q, 0)
        # Apply annealing
        phi_diff = phi_p - T * phi_q
        # Handle 0 * -inf: an infinite natural parameter paired with a zero
        # moment would otherwise turn the whole bound into nan
        phi_diff = np.where(u_q != 0, phi_diff, 0)
        # TODO/FIXME: Use einsum here?
        Z = np.sum(phi_diff * u_q, axis=axis_sum)

        L = L + Z

    if ignore_masked:
        return (np.sum(np.where(self.mask, L, 0))
                * self.broadcasting_multiplier(self.plates,
                                               np.shape(L),
                                               np.shape(self.mask))
                * np.prod(self.plates_multiplier))
    else:
        return (np.sum(L)
                * self.broadcasting_multiplier(self.plates,
                                               np.shape(L))
                * np.prod(self.plates_multiplier))
def _message_to_parent(self, index):
    """
    Compute the message to parent `index`, masked and summed to its plates.

    The raw message and mask come from
    ``self._get_message_and_mask_to_parent``.  Each non-None message
    element is multiplied by the mask (extended with the variable axes),
    summed from this node's plates to the parent's shape and scaled by the
    plate multiplier.  The message list is modified in place and returned.

    If the raw message is callable, a wrapper function is returned instead
    (see the review note in the body).

    Raises ``ValueError`` if `index` is not smaller than the number of
    parents, or if the plate multipliers of this node and the parent
    cannot be combined.
    """
    # Compute the message, check plates, apply mask and sum over some plates
    if index >= len(self.parents):
        raise ValueError("Parent index larger than the number of parents")

    # Compute the message and mask
    (m, mask) = self._get_message_and_mask_to_parent(index)
    mask = misc.squeeze(mask)

    # Plates in the mask
    plates_mask = np.shape(mask)

    # The parent we're sending the message to
    parent = self.parents[index]

    # Plates with respect to the parent
    plates_self = self._plates_to_parent(index)

    # Plate multiplier of the parent
    multiplier_parent = self._plates_multiplier_from_parent(index)

    # Check if m is a logpdf function (for black-box variational inference)
    if callable(m):
        # NOTE(review): this wrapper looks unfinished.  It reads `i`,
        # `mask_i` and `shape_parent`, none of which is bound when the
        # wrapper is called, and it returns nothing.  Left as is because
        # the intended behavior cannot be determined from here.
        def m_function(*args):
            lpdf = m(*args)
            # Log pdf only contains plate axes!
            plates_m = np.shape(lpdf)
            r = (self.broadcasting_multiplier(plates_self,
                                              plates_m,
                                              plates_mask,
                                              parent.plates)
                 * self.broadcasting_multiplier(self.plates_multiplier,
                                                multiplier_parent))
            axes_msg = misc.axes_to_collapse(plates_m, parent.plates)
            m[i] = misc.sum_multiply(mask_i, m[i], r,
                                     axis=axes_msg,
                                     keepdims=True)

            # Remove leading singular plates if the parent does not have
            # those plate axes.
            m[i] = misc.squeeze_to_dim(m[i], len(shape_parent))

        return m_function
        # NOTE(review): unreachable after the return above.
        raise NotImplementedError()

    # Compact the message to a proper shape
    for i in range(len(m)):

        # Empty messages are given as None. We can ignore those.
        if m[i] is not None:

            try:
                r = self.broadcasting_multiplier(self.plates_multiplier,
                                                 multiplier_parent)
            except Exception as err:
                # Only ordinary errors are translated (not KeyboardInterrupt
                # or SystemExit), and the original cause stays attached.
                raise ValueError("The plate multipliers are incompatible. "
                                 "This node (%s) has %s and parent[%d] "
                                 "(%s) has %s"
                                 % (self.name,
                                    self.plates_multiplier,
                                    index,
                                    parent.name,
                                    multiplier_parent)) from err

            ndim = len(parent.dims[i])
            # Source and target shapes
            if ndim > 0:
                dims = misc.broadcasted_shape(np.shape(m[i])[-ndim:],
                                              parent.dims[i])
                from_shape = plates_self + dims
            else:
                from_shape = plates_self
            to_shape = parent.get_shape(i)
            # Add variable axes to the mask
            mask_i = misc.add_trailing_axes(mask, ndim)
            # Apply mask and sum plate axes as necessary (and apply plate
            # multiplier)
            m[i] = r * misc.sum_multiply_to_plates(m[i], mask_i,
                                                   to_plates=to_shape,
                                                   from_plates=from_shape,
                                                   ndim=0)

    return m
def compute_message_to_parent(self, parent, index, u, *u_parents):
    """
    Compute the message to a parent node.

    ``index == 0`` addresses the cluster assignment parent: the message is
    a one-element list holding the log pdf with the cluster axis last,
    Shape = [Nn,..,N0,K].  ``index >= 1`` addresses parent ``index - 1`` of
    the mixed distribution: its message is computed with a unit cluster
    axis inserted into `u` and then weighted by the cluster assignment
    probabilities mapped through ``compute_weights_to_parent``.  For any
    other index nothing is returned.
    """
    if index == 0:
        dist_parents = u_parents[1:]

        # Shape(g) = [Nn,..,K,..,N0]
        g = self.distribution.compute_cgf_from_parents(*dist_parents)
        # Bring the cluster axis last: Shape(g) = [Nn,..,N0,K]
        if np.ndim(g) >= abs(self.cluster_plate):
            g = misc.moveaxis(g, self.cluster_plate, -1)
        else:
            # Not enough axes, so just append the cluster plate axis
            g = np.expand_dims(g, -1)

        # Shape(phi) = [Nn,..,K,..,N0,Dd,..,D0]
        phi = self.distribution.compute_phi_from_parents(*dist_parents)
        # Put the cluster axis right in front of the variable axes:
        # Shape(phi) = [Nn,..,N0,K,Dd,..,D0]
        for (ind, phi_i) in enumerate(phi):
            if self.cluster_plate >= 0:
                raise RuntimeError("Cluster plate axis must be negative")
            nd = self.ndims[ind]
            src = self.cluster_plate - nd
            dst = -1 - nd
            if np.ndim(phi_i) >= abs(src):
                # Cluster plate axis exists, move it into place
                phi[ind] = misc.moveaxis(phi_i, src, dst)
            elif np.ndim(phi_i) >= abs(dst):
                # No cluster plate axis: insert one, but only if phi has
                # something on that position
                phi[ind] = np.expand_dims(phi_i, axis=dst)

        # Unit cluster axis for the moments:
        # Shape(u) = [Nn,..,N0,1,Dd,..,D0]
        u_self = [
            np.expand_dims(u_i, axis=(-1 - self.ndims[ind]))
            for (ind, u_i) in enumerate(u)
        ]

        # Shape(L) = [Nn,..,N0,K].  Summing over the other plates is left
        # to the message passing machinery.
        logpdf = self.distribution.compute_logpdf(u_self, phi, g, 0,
                                                  self.ndims)
        return [logpdf]

    if index >= 1:
        # Parent index within the mixed distribution
        dist_index = index - 1

        # Shape(u) = [Nn,..1,..,N0,Dd,..,D0]
        u_self = []
        for (ind, u_i) in enumerate(u):
            if self.cluster_plate >= 0:
                raise ValueError("Cluster plate axis must be negative")
            u_self.append(
                np.expand_dims(u_i,
                               axis=self.cluster_plate - self.ndims[ind]))

        # Message from the mixed distribution
        raw = self.distribution.compute_message_to_parent(
            parent, dist_index, u_self, *(u_parents[1:]))

        # The cluster assignment probabilities act as weights on plate
        # elements, so they must go through the plate mapping of the mixed
        # distribution; otherwise nested mixtures (or any distribution that
        # changes the plates) would not work.  See issue #39.
        weights = misc.atleast_nd(u_parents[0][0], abs(self.cluster_plate))
        weights = misc.moveaxis(weights, -1, self.cluster_plate)
        weights = self.distribution.compute_weights_to_parent(
            dist_index,
            weights,
        )

        # Weigh the elements in the message array
        weighted = []
        for (msg, nd) in zip(raw, self.ndims_parents[dist_index]):
            weighted.append(msg * misc.add_trailing_axes(weights, nd))
        return weighted
def compute_phi_from_parents(self, *u_parents, mask=True):
    """
    Compute the natural parameter vector given parent moments.

    The natural parameters of the mixed distribution (computed from all
    parents but the first) are averaged over the clusters, weighted by the
    first moment of the first parent.  A warning is issued if any result
    contains nans.  `mask` is accepted but not used.

    Raises ``RuntimeError`` if the cluster plate is not negative.
    """
    # Cluster parameters
    cluster_phi = self.distribution.compute_phi_from_parents(
        *(u_parents[1:]))
    # Contributions/weights/probabilities
    weights = u_parents[0][0]

    averaged = []
    found_nan = False
    for (ind, phi_k) in enumerate(cluster_phi):
        # Shape(phi)    = [Nn,..,K,..,N0,Dd,..,D0]
        # Shape(p)      = [Nn,..,N0,K]
        # Shape(result) = [Nn,..,N0,Dd,..,D0]
        # The plate axes Nn,..,N0 follow general broadcasting rules:
        # leading axes may be missing or have length one.
        if self.cluster_plate >= 0:
            raise RuntimeError("Cluster plate should be negative")
        nd = self.ndims[ind]
        cluster_axis = self.cluster_plate - nd

        # Cluster axis last: Shape(phi) = [Nn,..,N0,Dd,..,D0,K]
        if np.ndim(phi_k) < abs(cluster_axis):
            phi_k = phi_k[..., None]
        else:
            phi_k = misc.moveaxis(phi_k, cluster_axis, -1)

        # Shape(p) = [Nn,..,N0,K,1,..,1] -> [Nn,..,N0,1,..,1,K]
        p = misc.moveaxis(misc.add_trailing_axes(weights, nd), -(nd + 1), -1)

        # Zero probability cases: avoids nans when p=0 and phi=inf
        phi_k = np.where(p != 0, phi_k, 0)

        # Sum p*phi over the cluster axis:
        # Shape(result) = [Nn,..,N0,Dd,..,D0]
        phi_k = misc.sum_product(p, phi_k, axes_to_sum=-1)
        averaged.append(phi_k)

        if np.any(np.isnan(phi_k)):
            found_nan = True

    if found_nan:
        warnings.warn("The natural parameters of mixture distribution "
                      "contain nans. This may happen if you use fixed "
                      "parameters in your model. Technically, one possible "
                      "reason is that the cluster assignment probability "
                      "for some element is zero (p=0) and the natural "
                      "parameter of that cluster is -inf, thus "
                      "0*(-inf)=nan. Solution: Use parameters that assign "
                      "non-zero probabilities for the whole domain.")

    return averaged
def compute_message_to_parent(self, parent, index, u, *u_parents):
    """
    Compute the message to a parent node.

    ``index == 0`` addresses the cluster assignment parent: the message is
    a one-element list holding the log pdf with the cluster axis last,
    Shape = [Nn,..,N0,K].  ``index >= 1`` addresses parent ``index - 1`` of
    the mixed distribution: its message is computed with a unit cluster
    axis inserted into `u` and then weighted by the cluster assignment
    probabilities mapped through ``compute_weights_to_parent``.  For any
    other index nothing is returned.
    """
    if index == 0:
        mixed_parents = u_parents[1:]

        # Shape(g) = [Nn,..,K,..,N0]
        cgf = self.distribution.compute_cgf_from_parents(*mixed_parents)
        # Bring the cluster axis last: Shape(g) = [Nn,..,N0,K]
        if np.ndim(cgf) >= abs(self.cluster_plate):
            cgf = misc.moveaxis(cgf, self.cluster_plate, -1)
        else:
            # Not enough axes, so just append the cluster plate axis
            cgf = np.expand_dims(cgf, -1)

        # Shape(phi) = [Nn,..,K,..,N0,Dd,..,D0]
        phi = self.distribution.compute_phi_from_parents(*mixed_parents)
        # Put the cluster axis right in front of the variable axes:
        # Shape(phi) = [Nn,..,N0,K,Dd,..,D0]
        for (k, phi_k) in enumerate(phi):
            if self.cluster_plate >= 0:
                raise RuntimeError("Cluster plate axis must be negative")
            n_var = self.ndims[k]
            axis_from = self.cluster_plate - n_var
            axis_to = -1 - n_var
            if np.ndim(phi_k) >= abs(axis_from):
                # Cluster plate axis exists, move it into place
                phi[k] = misc.moveaxis(phi_k, axis_from, axis_to)
            elif np.ndim(phi_k) >= abs(axis_to):
                # No cluster plate axis: insert one, but only if phi has
                # something on that position
                phi[k] = np.expand_dims(phi_k, axis=axis_to)

        # Unit cluster axis for the moments:
        # Shape(u) = [Nn,..,N0,1,Dd,..,D0]
        expanded = [
            np.expand_dims(moment, axis=(-1 - self.ndims[k]))
            for (k, moment) in enumerate(u)
        ]

        # Shape(L) = [Nn,..,N0,K].  Summing over the other plates is left
        # to the message passing machinery.
        return [
            self.distribution.compute_logpdf(expanded, phi, cgf, 0,
                                             self.ndims)
        ]

    if index >= 1:
        # Parent index within the mixed distribution
        inner_index = index - 1

        # Shape(u) = [Nn,..1,..,N0,Dd,..,D0]
        expanded = []
        for (k, moment) in enumerate(u):
            if self.cluster_plate >= 0:
                raise ValueError("Cluster plate axis must be negative")
            axis = self.cluster_plate - self.ndims[k]
            expanded.append(np.expand_dims(moment, axis=axis))

        # Message from the mixed distribution
        inner_msg = self.distribution.compute_message_to_parent(
            parent, inner_index, expanded, *(u_parents[1:]))

        # The cluster assignment probabilities act as weights on plate
        # elements, so they must go through the plate mapping of the mixed
        # distribution; otherwise nested mixtures (or any distribution that
        # changes the plates) would not work.  See issue #39.
        p = misc.moveaxis(
            misc.atleast_nd(u_parents[0][0], abs(self.cluster_plate)),
            -1,
            self.cluster_plate)
        p = self.distribution.compute_weights_to_parent(inner_index, p)

        # Weigh the elements in the message array
        result = []
        for (part, n_var) in zip(inner_msg, self.ndims_parents[inner_index]):
            result.append(part * misc.add_trailing_axes(p, n_var))
        return result
def compute_phi_from_parents(self, *u_parents, mask=True):
    """
    Compute the natural parameter vector given parent moments.

    The natural parameters of the mixed distribution (computed from all
    parents but the first) are averaged over the clusters, weighted by the
    first moment of the first parent.  A warning is issued if any result
    contains nans.  `mask` is accepted but not used.

    Raises ``RuntimeError`` if the cluster plate is not negative.
    """
    # Cluster parameters
    per_cluster = self.distribution.compute_phi_from_parents(
        *(u_parents[1:]))
    # Contributions/weights/probabilities
    probs = u_parents[0][0]

    result = []
    has_nans = False
    for (k, params) in enumerate(per_cluster):
        # Shape(phi)    = [Nn,..,K,..,N0,Dd,..,D0]
        # Shape(p)      = [Nn,..,N0,K]
        # Shape(result) = [Nn,..,N0,Dd,..,D0]
        # The plate axes Nn,..,N0 follow general broadcasting rules:
        # leading axes may be missing or have length one.
        if self.cluster_plate >= 0:
            raise RuntimeError("Cluster plate should be negative")
        n_var = self.ndims[k]
        axis = self.cluster_plate - n_var

        # Cluster axis last: Shape(phi) = [Nn,..,N0,Dd,..,D0,K]
        if np.ndim(params) >= abs(axis):
            params = misc.moveaxis(params, axis, -1)
        else:
            params = params[..., None]

        # Shape(p) = [Nn,..,N0,K,1,..,1]
        w = misc.add_trailing_axes(probs, n_var)
        # Shape(p) = [Nn,..,N0,1,..,1,K]
        w = misc.moveaxis(w, -(n_var + 1), -1)

        # Zero probability cases: avoids nans when p=0 and phi=inf
        params = np.where(w != 0, params, 0)

        # Sum p*phi over the cluster axis:
        # Shape(result) = [Nn,..,N0,Dd,..,D0]
        mixed = misc.sum_product(w, params, axes_to_sum=-1)
        result.append(mixed)

        if np.any(np.isnan(mixed)):
            has_nans = True

    if has_nans:
        warnings.warn(
            "The natural parameters of mixture distribution "
            "contain nans. This may happen if you use fixed "
            "parameters in your model. Technically, one possible "
            "reason is that the cluster assignment probability "
            "for some element is zero (p=0) and the natural "
            "parameter of that cluster is -inf, thus "
            "0*(-inf)=nan. Solution: Use parameters that assign "
            "non-zero probabilities for the whole domain.")

    return result
def compute_message_to_parent(self, parent, index, u, *u_parents):
    """
    Compute the message to a parent node.

    ``index == 0`` addresses the cluster assignment parent: the message is
    a one-element list holding the log pdf with the cluster axis last,
    Shape = [Nn,..,N0,K].  ``index >= 1`` addresses parent ``index - 1`` of
    the mixed distribution: its message is computed with a unit cluster
    axis inserted into `u`, and each element is then multiplied by the
    cluster responsibilities (the first moment of the first parent).  For
    any other index nothing is returned.
    """
    if index == 0:
        dist_parents = u_parents[1:]

        # Shape(g) = [Nn,..,K,..,N0]
        g = self.distribution.compute_cgf_from_parents(*dist_parents)
        # Bring the cluster axis last: Shape(g) = [Nn,..,N0,K]
        if np.ndim(g) >= abs(self.cluster_plate):
            g = misc.moveaxis(g, self.cluster_plate, -1)
        else:
            # Not enough axes, so just append the cluster plate axis
            g = np.expand_dims(g, -1)

        # Shape(phi) = [Nn,..,K,..,N0,Dd,..,D0]
        phi = self.distribution.compute_phi_from_parents(*dist_parents)
        # Put the cluster axis right in front of the variable axes:
        # Shape(phi) = [Nn,..,N0,K,Dd,..,D0]
        for (ind, phi_i) in enumerate(phi):
            if self.cluster_plate >= 0:
                raise RuntimeError("Cluster plate axis must be negative")
            nd = self.ndims[ind]
            src = self.cluster_plate - nd
            dst = -1 - nd
            if np.ndim(phi_i) >= abs(src):
                # Cluster plate axis exists, move it into place
                phi[ind] = misc.moveaxis(phi_i, src, dst)
            elif np.ndim(phi_i) >= abs(dst):
                # No cluster plate axis: insert one, but only if phi has
                # something on that position
                phi[ind] = np.expand_dims(phi_i, axis=dst)

        # Unit cluster axis for the moments:
        # Shape(u) = [Nn,..,N0,1,Dd,..,D0]
        u_self = [
            np.expand_dims(u_i, axis=(-1 - self.ndims[ind]))
            for (ind, u_i) in enumerate(u)
        ]

        # Shape(L) = [Nn,..,N0,K].  Summing over the other plates is left
        # to the message passing machinery.
        logpdf = self.distribution.compute_logpdf(u_self, phi, g, 0,
                                                  self.ndims)
        return [logpdf]

    if index >= 1:
        # Parent index within the mixed distribution
        dist_index = index - 1

        # Shape(u) = [Nn,..1,..,N0,Dd,..,D0]
        u_self = []
        for (ind, u_i) in enumerate(u):
            if self.cluster_plate < 0:
                axis = self.cluster_plate - self.ndims[ind]
            else:
                axis = self.cluster_plate
            u_self.append(np.expand_dims(u_i, axis=axis))

        # Message from the mixed distribution
        m = self.distribution.compute_message_to_parent(
            parent, dist_index, u_self, *(u_parents[1:]))

        # Weigh the messages with the responsibilities, in place:
        # Shape(m)      = [Nn,..,K,..,N0,Dd,..,D0]
        # Shape(p)      = [Nn,..,N0,K]
        # Shape(result) = [Nn,..,K,..,N0,Dd,..,D0]
        for (i, msg) in enumerate(m):
            # Number of variable axes in this element of the parent message
            n_var = self.ndims_parents[dist_index][i]

            # Responsibilities with the cluster axis moved into place:
            # Shape(p) = [Nn,..,K,..,N0]
            resp = misc.atleast_nd(u_parents[0][0], abs(self.cluster_plate))
            resp = misc.moveaxis(resp, -1, self.cluster_plate)

            # Shape(p) = [Nn,..,K,..,N0,1,..,1]
            resp = misc.add_trailing_axes(resp, n_var)

            # TODO: You could do summing here already so that you wouldn't
            # compute huge matrices as intermediate result. Use einsum.
            m[i] = msg * resp

        return m
def compute_message_to_parent(self, parent, index, u, *u_parents):
    """
    Compute the message to a parent node.

    ``index == 0`` addresses the cluster assignment parent: the message is
    a one-element list holding the log pdf with the cluster axis last,
    Shape = [Nn,..,N0,K].  ``index >= 1`` addresses parent ``index - 1`` of
    the mixed distribution: its message is computed with a unit cluster
    axis inserted into `u`, and each element is then multiplied by the
    cluster responsibilities (the first moment of the first parent).  For
    any other index nothing is returned.
    """
    if index == 0:
        mixed_parents = u_parents[1:]

        # Shape(g) = [Nn,..,K,..,N0]
        cgf = self.distribution.compute_cgf_from_parents(*mixed_parents)
        # Bring the cluster axis last: Shape(g) = [Nn,..,N0,K]
        if np.ndim(cgf) >= abs(self.cluster_plate):
            cgf = misc.moveaxis(cgf, self.cluster_plate, -1)
        else:
            # Not enough axes, so just append the cluster plate axis
            cgf = np.expand_dims(cgf, -1)

        # Shape(phi) = [Nn,..,K,..,N0,Dd,..,D0]
        phi = self.distribution.compute_phi_from_parents(*mixed_parents)
        # Put the cluster axis right in front of the variable axes:
        # Shape(phi) = [Nn,..,N0,K,Dd,..,D0]
        for (k, phi_k) in enumerate(phi):
            if self.cluster_plate >= 0:
                raise RuntimeError("Cluster plate axis must be negative")
            n_var = self.ndims[k]
            axis_from = self.cluster_plate - n_var
            axis_to = -1 - n_var
            if np.ndim(phi_k) >= abs(axis_from):
                # Cluster plate axis exists, move it into place
                phi[k] = misc.moveaxis(phi_k, axis_from, axis_to)
            elif np.ndim(phi_k) >= abs(axis_to):
                # No cluster plate axis: insert one, but only if phi has
                # something on that position
                phi[k] = np.expand_dims(phi_k, axis=axis_to)

        # Unit cluster axis for the moments:
        # Shape(u) = [Nn,..,N0,1,Dd,..,D0]
        expanded = [
            np.expand_dims(moment, axis=(-1 - self.ndims[k]))
            for (k, moment) in enumerate(u)
        ]

        # Shape(L) = [Nn,..,N0,K].  Summing over the other plates is left
        # to the message passing machinery.
        return [
            self.distribution.compute_logpdf(expanded, phi, cgf, 0,
                                             self.ndims)
        ]

    if index >= 1:
        # Parent index within the mixed distribution
        inner_index = index - 1

        # Shape(u) = [Nn,..1,..,N0,Dd,..,D0]
        expanded = []
        for (k, moment) in enumerate(u):
            axis = self.cluster_plate
            if self.cluster_plate < 0:
                # Count the position from in front of the variable axes
                axis = self.cluster_plate - self.ndims[k]
            expanded.append(np.expand_dims(moment, axis=axis))

        # Message from the mixed distribution
        m = self.distribution.compute_message_to_parent(
            parent, inner_index, expanded, *(u_parents[1:]))

        # Weigh the messages with the responsibilities, in place:
        # Shape(m)      = [Nn,..,K,..,N0,Dd,..,D0]
        # Shape(p)      = [Nn,..,N0,K]
        # Shape(result) = [Nn,..,K,..,N0,Dd,..,D0]
        for (j, part) in enumerate(m):
            # Number of variable axes in this element of the parent message
            n_var = self.ndims_parents[inner_index][j]

            # Responsibilities are the first parent's first moment; move
            # the cluster axis into place: Shape(p) = [Nn,..,K,..,N0]
            p = misc.moveaxis(
                misc.atleast_nd(u_parents[0][0], abs(self.cluster_plate)),
                -1,
                self.cluster_plate)

            # Shape(p) = [Nn,..,K,..,N0,1,..,1]
            p = misc.add_trailing_axes(p, n_var)

            # TODO: You could do summing here already so that you wouldn't
            # compute huge matrices as intermediate result. Use einsum.
            m[j] = part * p

        return m
def _message_to_parent(self, index):
    """
    Compute the message to a parent, apply the mask and sum over plates.

    Parameters
    ----------
    index : int
        Index of the receiving parent in ``self.parents``.

    Returns
    -------
    list or callable
        If the raw message is a list, the same list with every non-``None``
        element masked, multiplied by the plate multiplier and summed to
        the parent's shape. If the raw message is callable (a log pdf
        function for black-box variational inference), a function that
        evaluates it and applies the mask, the multipliers and the plate
        collapsing to the result.

    Raises
    ------
    ValueError
        If ``index`` is not smaller than the number of parents, or if the
        plate multipliers of this node and the parent cannot be combined.
    """
    if index >= len(self.parents):
        raise ValueError("Parent index larger than the number of parents")

    # Compute the message and mask
    (m, mask) = self._get_message_and_mask_to_parent(index)
    mask = misc.squeeze(mask)

    # Plates in the mask
    plates_mask = np.shape(mask)

    # The parent we're sending the message to
    parent = self.parents[index]

    # Plates with respect to the parent
    plates_self = self._plates_to_parent(index)

    # Plate multiplier of the parent
    multiplier_parent = self._plates_multiplier_from_parent(index)

    # Check if m is a logpdf function (for black-box variational inference)
    if callable(m):

        def m_function(*args):
            lpdf = m(*args)
            # Log pdf only contains plate axes!
            plates_m = np.shape(lpdf)
            r = (
                self.broadcasting_multiplier(
                    plates_self,
                    plates_m,
                    plates_mask,
                    parent.plates
                )
                * self.broadcasting_multiplier(
                    self.plates_multiplier,
                    multiplier_parent
                )
            )
            axes_msg = misc.axes_to_collapse(plates_m, parent.plates)
            # The log pdf has no variable axes, so the mask is used as is.
            lpdf = misc.sum_multiply(
                mask,
                lpdf,
                r,
                axis=axes_msg,
                keepdims=True
            )
            # Remove leading singular plates if the parent does not have
            # those plate axes.
            # NOTE(review): the previous code referred to an undefined
            # ``shape_parent`` here; the number of parent plates is assumed
            # to be the intended target dimensionality -- confirm.
            return misc.squeeze_to_dim(lpdf, len(parent.plates))

        return m_function

    # Compact the message to a proper shape
    for (i, msg) in enumerate(m):

        # Empty messages are given as None. We can ignore those.
        if msg is None:
            continue

        try:
            r = self.broadcasting_multiplier(self.plates_multiplier,
                                             multiplier_parent)
        except Exception as err:
            raise ValueError("The plate multipliers are incompatible. "
                             "This node (%s) has %s and parent[%d] "
                             "(%s) has %s"
                             % (self.name,
                                self.plates_multiplier,
                                index,
                                parent.name,
                                multiplier_parent)) from err

        ndim = len(parent.dims[i])
        # Source and target shapes
        if ndim > 0:
            dims = misc.broadcasted_shape(np.shape(msg)[-ndim:],
                                          parent.dims[i])
            from_shape = plates_self + dims
        else:
            from_shape = plates_self
        to_shape = parent.get_shape(i)
        # Add variable axes to the mask
        mask_i = misc.add_trailing_axes(mask, ndim)
        # Apply mask and sum plate axes as necessary (and apply plate
        # multiplier)
        m[i] = r * misc.sum_multiply_to_plates(msg, mask_i,
                                               to_plates=to_shape,
                                               from_plates=from_shape,
                                               ndim=0)

    return m
def compute_message_to_parent(self, parent, index, u, *u_parents):
    """
    Compute the message to a parent node.

    Parameters
    ----------
    parent :
        The parent node. It is only forwarded to the mixed distribution's
        own ``compute_message_to_parent``.
    index : int
        Parent index. ``0`` selects the cluster-assignment parent; any
        value ``>= 1`` selects parent ``index - 1`` of the mixed
        distribution.
    u : sequence of arrays
        Moments of this node.
    *u_parents :
        Moments of the parents. ``u_parents[0][0]`` is used as the cluster
        assignment probabilities; ``u_parents[1:]`` are forwarded to the
        mixed distribution.

    Returns
    -------
    list
        The message, one array per moment of the receiving parent.

    Raises
    ------
    ValueError
        If ``index`` is negative, or (for ``index >= 1`` with non-empty
        ``u``) if ``self.cluster_plate`` is not negative.
    """
    if index == 0:

        # Shape(phi)    = [Nn,..,K,..,N0,Dd,..,D0]
        # Shape(L)      = [Nn,..,K,..,N0]
        # Shape(u)      = [Nn,..,N0,Dd,..,D0]
        # Shape(result) = [Nn,..,N0,K]

        # Compute g:
        # Shape(g)      = [Nn,..,K,..,N0]
        g = self.raw_distribution.compute_cgf_from_parents(
            *(u_parents[1:])
        )
        # Reshape(g):
        # Shape(g)      = [Nn,..,N0,K]
        #
        # NOTE(review): g is moved to the trailing-K layout here, whereas
        # phi and u below keep K at the cluster plate and L is only moved
        # after compute_logpdf. For cluster_plate != -1 these layouts look
        # inconsistent -- confirm what compute_logpdf expects.
        if np.ndim(g) < abs(self.cluster_plate):
            # Not enough axes, just add the cluster plate axis
            g = np.expand_dims(g, -1)
        else:
            # Move the cluster plate axis
            g = misc.moveaxis(g, self.cluster_plate, -1)

        # Compute phi:
        # Shape(phi)    = [Nn,..,K,..,N0,Dd,..,D0]
        phi = self.raw_distribution.compute_phi_from_parents(
            *(u_parents[1:])
        )

        # Reshape u (a moment with too few axes is left as it is):
        # Shape(u)      = [Nn,..,1,..,N0,Dd,..,D0]
        u_reshaped = [
            np.expand_dims(ui, self.cluster_plate - ndimi)
            if np.ndim(ui) >= abs(self.cluster_plate - ndimi) else
            ui
            for (ui, ndimi) in zip(u, self.ndims)
        ]

        # Compute logpdf:
        # Shape(L)      = [Nn,..,K,..,N0]
        L = self.raw_distribution.compute_logpdf(
            u_reshaped,
            phi,
            g,
            0,
            self.ndims,
        )

        # Move axis:
        # Shape(L)      = [Nn,..,N0,K]
        L = np.moveaxis(L, self.cluster_plate, -1)

        return [L]

    if index >= 1:

        # Parent index for the distribution used for the mixture.
        index_for_parent = index - 1

        # Reshape u:
        # Shape(u_self) = [Nn,..1,..,N0,Dd,..,D0]
        u_self = []
        for (ui, ndimi) in zip(u, self.ndims):
            if self.cluster_plate >= 0:
                raise ValueError("Cluster plate axis must be negative")
            u_self.append(
                np.expand_dims(ui, axis=self.cluster_plate - ndimi)
            )

        # Message from the mixed distribution
        # Shape(m)      = [Nn,..,K,..,N0,Dd,..,D0]
        m = self.raw_distribution.compute_message_to_parent(
            parent,
            index_for_parent,
            u_self,
            *(u_parents[1:])
        )

        # Note: The cluster assignment probabilities can be considered as
        # weights to plate elements. These weights need to be mapped
        # properly via the plate mapping of self.distribution. Otherwise,
        # nested mixtures won't work, or possibly not any distribution that
        # does something to the plates. Thus, use compute_weights_to_parent
        # to compute the transformations to the weight array properly.
        #
        # See issue #39 for more details.

        # Compute weights (i.e., cluster assignment probabilities) and map
        # the plates properly.
        # Shape(p)      = [Nn,..,K,..,N0]
        p = misc.atleast_nd(u_parents[0][0], abs(self.cluster_plate))
        p = misc.moveaxis(p, -1, self.cluster_plate)
        p = self.raw_distribution.compute_weights_to_parent(
            index_for_parent,
            p,
        )

        # Weigh the elements in the message array
        #
        # TODO/FIXME: This may result in huge intermediate arrays. Need to
        # use einsum!
        return [
            mi * misc.add_trailing_axes(p, ndim)
            for (mi, ndim) in zip(m, self.ndims_parents[index_for_parent])
        ]

    # Negative indices used to fall through and return None silently.
    raise ValueError("Invalid parent index")