Exemplo n.º 1
0
    def _compute_message_to_parent(self, index, m_child, u_Z, u_X):
        """
        Compute the message from this node to one of its two parents.

        Parameters
        ----------
        index : int
            0 for the parent whose moments are ``u_Z``, 1 for the parent
            whose moments are ``u_X``.
        m_child : list of arrays
            Message from the children, one array per moment.
        u_Z, u_X : list of arrays
            Moments of the two parents.

        Returns
        -------
        list of arrays
            For index 0 a one-element list: ``m_child[i] * u_X[i]`` summed
            over the variable axes and over all moments, with the gated plate
            of ``u_X`` as the last axis, multiplied by ``np.ones(self.K)`` so
            that axis is not left broadcast.  For index 1 one array per
            moment: ``u_Z[0] * m_child[i]`` with the last axis of ``u_Z[0]``
            moved to the gated plate position.

        Raises
        ------
        ValueError
            If ``index`` is neither 0 nor 1.
        """
        if index == 0:
            total = 0
            for (i, msg) in enumerate(m_child):
                var_ndim = len(self.dims[i])
                plate_axis = self.gated_plate - var_ndim
                # Child message gets a singleton axis right before the
                # variable axes so it lines up with the gated axis of X.
                child = misc.moveaxis(msg[..., None], -1, -var_ndim - 1)
                x_i = u_X[i]
                if np.ndim(x_i) >= abs(plate_axis):
                    x_i = misc.moveaxis(x_i, plate_axis, -var_ndim - 1)
                else:
                    # X does not have the gated plate axis at all
                    x_i = np.expand_dims(x_i, -var_ndim - 1)
                var_axes = tuple(range(-var_ndim, 0))
                total = total + misc.sum_product(child, x_i,
                                                 axes_to_sum=var_axes)
            # Make sure the variable axis does not use broadcasting
            return [total * np.ones(self.K)]

        if index == 1:
            messages = []
            for (i, msg) in enumerate(m_child):
                var_ndim = len(self.dims[i])
                plate_axis = self.gated_plate - var_ndim
                # Z moments: add variable axes, then move the gate axis last
                z = misc.moveaxis(misc.add_trailing_axes(u_Z[0], var_ndim),
                                  -var_ndim - 1, -1)
                # Child message: add a trailing gate axis and weight by Z
                weighted = z * misc.add_trailing_axes(msg, 1)
                missing = abs(plate_axis) - np.ndim(weighted)
                if missing > 0:
                    weighted = misc.add_leading_axes(weighted, missing)
                # Put the gate axis where the gated plate lives
                messages.append(misc.moveaxis(weighted, -1, plate_axis))
            return messages

        raise ValueError("Invalid parent index")
Exemplo n.º 2
0
    def _compute_message_to_parent(self, index, m_child, u_Z, u_X):
        """
        Compute the message from this node to one of its two parents.

        Parameters
        ----------
        index : int
            0 for the parent whose moments are ``u_Z``, 1 for the parent
            whose moments are ``u_X``.
        m_child : list of arrays
            Message from the children, one array per moment.
        u_Z, u_X : list of arrays
            Moments of the two parents.

        Returns
        -------
        list of arrays
            For index 0 a one-element list: ``m_child[i] * u_X[i]`` summed
            over the variable axes and over all moments, with the gated plate
            of ``u_X`` as the last axis, multiplied by ``np.ones(self.K)`` so
            that axis is not left broadcast.  For index 1 one array per
            moment: ``u_Z[0] * m_child[i]`` with the last axis of ``u_Z[0]``
            moved to the gated plate position.

        Raises
        ------
        ValueError
            If ``index`` is neither 0 nor 1.
        """
        if index == 0:
            total = 0
            for (i, msg) in enumerate(m_child):
                var_ndim = len(self.dims[i])
                plate_axis = self.gated_plate - var_ndim
                # Child message gets a singleton axis right before the
                # variable axes so it lines up with the gated axis of X.
                child = misc.moveaxis(msg[..., None], -1, -var_ndim - 1)
                x_i = u_X[i]
                if np.ndim(x_i) >= abs(plate_axis):
                    x_i = misc.moveaxis(x_i, plate_axis, -var_ndim - 1)
                else:
                    # X does not have the gated plate axis at all
                    x_i = np.expand_dims(x_i, -var_ndim - 1)
                var_axes = tuple(range(-var_ndim, 0))
                total = total + misc.sum_product(child, x_i,
                                                 axes_to_sum=var_axes)
            # Make sure the variable axis does not use broadcasting
            return [total * np.ones(self.K)]

        if index == 1:
            messages = []
            for (i, msg) in enumerate(m_child):
                var_ndim = len(self.dims[i])
                plate_axis = self.gated_plate - var_ndim
                # Z moments: add variable axes, then move the gate axis last
                z = misc.moveaxis(misc.add_trailing_axes(u_Z[0], var_ndim),
                                  -var_ndim - 1, -1)
                # Child message: add a trailing gate axis and weight by Z
                weighted = z * misc.add_trailing_axes(msg, 1)
                missing = abs(plate_axis) - np.ndim(weighted)
                if missing > 0:
                    weighted = misc.add_leading_axes(weighted, missing)
                # Put the gate axis where the gated plate lives
                messages.append(misc.moveaxis(weighted, -1, plate_axis))
            return messages

        raise ValueError("Invalid parent index")
Exemplo n.º 3
0
    def _compute_moments(self, u_Z, u_X):
        """
        Compute the moments of this node from the parent moments.

        For every moment of X, the gated plate axis of ``u_X[i]`` is
        contracted against the last axis of ``u_Z[0]``.  Both arrays are
        arranged so that this axis is last, and the two are then
        sum-producted over it.  If ``u_X[i]`` lacks the gated plate axis, a
        singleton axis is appended instead.

        Returns a list with one array per moment of X.
        """
        moments = []
        for (i, x_i) in enumerate(u_X):
            var_ndim = len(self.dims[i])
            plate_axis = self.gated_plate - var_ndim
            # Z: add variable axes, then move its last (gate) axis to the end
            z = misc.moveaxis(misc.add_trailing_axes(u_Z[0], var_ndim),
                              -var_ndim - 1, -1)
            if np.ndim(x_i) >= abs(plate_axis):
                x_i = misc.moveaxis(x_i, plate_axis, -1)
            else:
                x_i = misc.add_trailing_axes(x_i, 1)
            moments.append(misc.sum_product(z, x_i, axes_to_sum=-1))
        return moments
Exemplo n.º 4
0
    def _compute_moments(self, u_Z, u_X):
        """
        Compute the moments of this node from the parent moments.

        For every moment of X, the gated plate axis of ``u_X[i]`` is
        contracted against the last axis of ``u_Z[0]``.  Both arrays are
        arranged so that this axis is last, and the two are then
        sum-producted over it.  If ``u_X[i]`` lacks the gated plate axis, a
        singleton axis is appended instead.

        Returns a list with one array per moment of X.
        """
        moments = []
        for (i, x_i) in enumerate(u_X):
            var_ndim = len(self.dims[i])
            plate_axis = self.gated_plate - var_ndim
            # Z: add variable axes, then move its last (gate) axis to the end
            z = misc.moveaxis(misc.add_trailing_axes(u_Z[0], var_ndim),
                              -var_ndim - 1, -1)
            if np.ndim(x_i) >= abs(plate_axis):
                x_i = misc.moveaxis(x_i, plate_axis, -1)
            else:
                x_i = misc.add_trailing_axes(x_i, 1)
            moments.append(misc.sum_product(z, x_i, axes_to_sum=-1))
        return moments
Exemplo n.º 5
0
    def _compute_moments(self, u_Lambda, u_alpha):
        """
        Compute the moments of the product of ``alpha`` and ``Lambda``.

        Returns ``[Lambda * alpha, logdet_Lambda + D * log_alpha]``, where
        ``D = np.prod(self._moments.shape)``.  ``alpha`` (``u_alpha[0]``) is
        given ``2 * self._moments.ndim`` trailing axes so it broadcasts over
        the matrix axes of ``Lambda``.
        """
        (Lambda, logdet_Lambda) = (u_Lambda[0], u_Lambda[1])
        n_matrix_axes = 2 * self._moments.ndim
        scaled = Lambda * misc.add_trailing_axes(u_alpha[0], n_matrix_axes)
        scaled_logdet = (logdet_Lambda
                         + np.prod(self._moments.shape) * u_alpha[1])
        return [scaled, scaled_logdet]
Exemplo n.º 6
0
    def _compute_moments(self, u_Lambda, u_alpha):
        """
        Compute the moments of the product of ``alpha`` and ``Lambda``.

        Returns ``[Lambda * alpha, logdet_Lambda + D * log_alpha]``, where
        ``D = np.prod(self._moments.shape)``.  ``alpha`` (``u_alpha[0]``) is
        given ``2 * self._moments.ndim`` trailing axes so it broadcasts over
        the matrix axes of ``Lambda``.
        """
        (Lambda, logdet_Lambda) = (u_Lambda[0], u_Lambda[1])
        n_matrix_axes = 2 * self._moments.ndim
        scaled = Lambda * misc.add_trailing_axes(u_alpha[0], n_matrix_axes)
        scaled_logdet = (logdet_Lambda
                         + np.prod(self._moments.shape) * u_alpha[1])
        return [scaled, scaled_logdet]
Exemplo n.º 7
0
    def _set_moments(self, u, mask=True, broadcast=True):
        """
        Store the moments ``u`` into ``self.u`` where ``mask`` is true.

        Parameters
        ----------
        u : list of arrays
            New moments, one array per moment.
        mask : bool or array of bool
            Plate-shaped mask selecting the elements to update.  Elements
            where the mask is false keep their previously stored moments.
            This function is also used to set observations, so the caller is
            responsible for passing a mask that selects exactly the elements
            that should change.
        broadcast : bool
            Passed to ``self._check_shape``.

        Raises
        ------
        RuntimeError
            If a stored moment array has more axes than
            ``self.get_shape(ind)`` (should be prevented by
            ``_check_shape``).
        """
        self._check_shape(u, broadcast=broadcast)

        for ind in range(len(u)):
            # The mask contains only plate axes, so add axes for the variable
            # dimensions.
            u_mask = misc.add_trailing_axes(mask, self.ndims[ind])

            # Enlarge self.u[ind] as necessary so that it can store the
            # broadcasted result.
            sh = misc.broadcasted_shape_from_arrays(self.u[ind], u[ind], u_mask)
            self.u[ind] = misc.repeat_to_shape(self.u[ind], sh)

            # Update only the elements selected by the mask and keep the
            # others as before.
            np.copyto(self.u[ind],
                      u[ind],
                      where=u_mask)

            # Make sure the stored moments have the correct number of
            # dimensions.  Pad the stored (masked) array itself: padding the
            # raw input u[ind] would throw away the masked update above, and
            # ndim_u is measured on self.u[ind], not on u[ind].
            shape = self.get_shape(ind)
            ndim = len(shape)
            ndim_u = np.ndim(self.u[ind])
            if ndim > ndim_u:
                self.u[ind] = misc.add_leading_axes(self.u[ind], ndim - ndim_u)
            elif ndim < ndim_u:
                # This should not ever happen because we already checked the
                # shape at the beginning of the function.
                raise RuntimeError(
                    "This error should not happen. Fix shape checking. "
                    "The size of the variable %s's %s-th moment "
                    "array is %s which is larger than it should "
                    "be, that is, %s, based on the plates %s and "
                    "dimension %s. Check that you have provided "
                    "plates properly."
                    % (self.name,
                       ind,
                       np.shape(self.u[ind]),
                       shape,
                       self.plates,
                       self.dims[ind]))
Exemplo n.º 8
0
    def compute_fixed_moments(self, Lambda, gradient=None):
        """
        Compute moments for a fixed matrix ``Lambda``.

        The moments are ``[Lambda, logdet(Lambda)]``; the log-determinant
        comes from the Cholesky factor, so ``Lambda`` must admit one.

        If ``gradient`` (a pair of gradients with respect to the two moments)
        is given, returns ``(moments, d)`` with
        ``d = gradient[0] + gradient[1] * linalg.chol_inv(L)``, where ``L``
        is the Cholesky factor of ``Lambda`` (``chol_inv`` presumably yields
        the inverse of ``Lambda`` -- confirm in ``linalg``).  Otherwise
        returns just the moments.
        """
        chol_factor = linalg.chol(Lambda, ndim=self.ndim)
        moments = [Lambda, linalg.chol_logdet(chol_factor, ndim=self.ndim)]

        if gradient is not None:
            # Broadcast the scalar log-det gradient over the matrix axes
            grad_logdet = misc.add_trailing_axes(gradient[1], 2 * self.ndim)
            total_grad = gradient[0] + (
                grad_logdet * linalg.chol_inv(chol_factor, ndim=self.ndim))
            return (moments, total_grad)

        return moments
Exemplo n.º 9
0
    def _compute_message_to_parent(self, index, m, u_Lambda, u_alpha):
        """
        Compute the message to parent ``index``.

        Parameters
        ----------
        index : int
            0 for the parent with moments ``u_Lambda``, 1 for the parent with
            moments ``u_alpha``.
        m : list of arrays
            Message from the children, ``[m0, m1]``.
        u_Lambda, u_alpha : list of arrays
            Parent moments.

        Returns
        -------
        list of arrays
            index 0: ``[m[0] * alpha, m[1]]``, with ``alpha`` broadcast over
            the ``2 * self._moments.ndim`` matrix axes.
            index 1: ``[<m[0], Lambda>, m[1] * np.prod(self._moments.shape)]``,
            where the inner product runs over the matrix axes.

        Raises
        ------
        IndexError
            If ``index`` is neither 0 nor 1.
        """
        if index == 0:
            alpha = misc.add_trailing_axes(u_alpha[0], 2 * self._moments.ndim)
            return [m[0] * alpha, m[1]]

        if index == 1:
            Lambda = u_Lambda[0]
            m0 = linalg.inner(m[0], Lambda, ndim=2 * self._moments.ndim)
            m1 = m[1] * np.prod(self._moments.shape)
            return [m0, m1]

        raise IndexError("Invalid parent index %r" % (index,))
Exemplo n.º 10
0
    def _compute_message_to_parent(self, index, m, u_Lambda, u_alpha):
        """
        Compute the message to parent ``index``.

        Parameters
        ----------
        index : int
            0 for the parent with moments ``u_Lambda``, 1 for the parent with
            moments ``u_alpha``.
        m : list of arrays
            Message from the children, ``[m0, m1]``.
        u_Lambda, u_alpha : list of arrays
            Parent moments.

        Returns
        -------
        list of arrays
            index 0: ``[m[0] * alpha, m[1]]``, with ``alpha`` broadcast over
            the ``2 * self._moments.ndim`` matrix axes.
            index 1: ``[<m[0], Lambda>, m[1] * np.prod(self._moments.shape)]``,
            where the inner product runs over the matrix axes.

        Raises
        ------
        IndexError
            If ``index`` is neither 0 nor 1.
        """
        if index == 0:
            alpha = misc.add_trailing_axes(u_alpha[0], 2 * self._moments.ndim)
            return [m[0] * alpha, m[1]]

        if index == 1:
            Lambda = u_Lambda[0]
            m0 = linalg.inner(m[0], Lambda, ndim=2 * self._moments.ndim)
            m1 = m[1] * np.prod(self._moments.shape)
            return [m0, m1]

        raise IndexError("Invalid parent index %r" % (index,))
Exemplo n.º 11
0
    def compute_fixed_moments(self, Lambda, gradient=None):
        """
        Compute moments for a fixed matrix ``Lambda``.

        The moments are ``[Lambda, logdet(Lambda)]``; the log-determinant
        comes from the Cholesky factor, so ``Lambda`` must admit one.

        If ``gradient`` (a pair of gradients with respect to the two moments)
        is given, returns ``(moments, d)`` with
        ``d = gradient[0] + gradient[1] * linalg.chol_inv(L)``, where ``L``
        is the Cholesky factor of ``Lambda`` (``chol_inv`` presumably yields
        the inverse of ``Lambda`` -- confirm in ``linalg``).  Otherwise
        returns just the moments.
        """
        chol_factor = linalg.chol(Lambda, ndim=self.ndim)
        moments = [Lambda, linalg.chol_logdet(chol_factor, ndim=self.ndim)]

        if gradient is not None:
            # Broadcast the scalar log-det gradient over the matrix axes
            grad_logdet = misc.add_trailing_axes(gradient[1], 2 * self.ndim)
            total_grad = gradient[0] + (
                grad_logdet * linalg.chol_inv(chol_factor, ndim=self.ndim))
            return (moments, total_grad)

        return moments
Exemplo n.º 12
0
    def lower_bound_contribution(self, gradient=False):
        """
        Compute E[ log p(X|parents) - log q(X) ] over q(X)q(parents).

        The per-plate term is ``g_p + (f if observed else -g_q)`` plus
        ``sum((phi_p - phi_q) * u_q)`` over the variable axes, where
        ``phi_q`` is zeroed for observed elements.  Elements whose moment
        ``u_q`` is exactly zero contribute nothing, which avoids
        ``0 * (-inf) = nan`` when a natural parameter is infinite.  The sum
        over plates uses only elements where ``self.mask`` is true and is
        scaled by ``self._plate_multiplier``.

        ``gradient`` is accepted for interface compatibility and not used.
        """
        # Messages from parents
        u_parents = self._message_from_parents()
        phi = self._distribution.compute_phi_from_parents(*u_parents)
        # G from parents
        L = self._distribution.compute_cgf_from_parents(*u_parents)
        # G for unobserved variables (ignored variables are handled
        # properly automatically)
        latent_mask = np.logical_not(self.observed)
        # F for observed, G for latent
        L = L + np.where(self.observed, self.f, -self.g)
        for (phi_p, phi_q, u_q, dims) in zip(phi, self.phi, self.u, self.dims):
            # Form a mask which puts observed variables to zero and
            # broadcasts properly
            latent_mask_i = misc.add_trailing_axes(
                misc.add_leading_axes(
                    latent_mask,
                    len(self.plates) - np.ndim(latent_mask)),
                len(dims))
            axis_sum = tuple(range(-len(dims), 0))

            # Compute the term
            phi_q = np.where(latent_mask_i, phi_q, 0)
            phi_diff = phi_p - phi_q
            # Handle 0 * -inf: zero moments contribute nothing
            phi_diff = np.where(u_q != 0, phi_diff, 0)
            Z = np.sum(phi_diff * u_q, axis=axis_sum)

            L = L + Z

        return (np.sum(np.where(self.mask, L, 0))
                * self._plate_multiplier(self.plates,
                                         np.shape(L),
                                         np.shape(self.mask)))
Exemplo n.º 13
0
    def lower_bound_contribution(self, gradient=False, ignore_masked=True):
        r"""Compute E[ log p(X|parents) - log q(X) ]

        With deterministic annealing, the term E[ -log q(X) ] is divided by
        the annealing coefficient, i.e. phi and the cgf of q are multiplied
        by the temperature ``T = 1 / self.annealing``.

        Parameters
        ----------
        gradient : bool
            Not used.
        ignore_masked : bool
            If true, only plate elements where ``self.mask`` is true are
            summed; otherwise all elements are summed.

        Returns
        -------
        The scalar contribution, scaled by the broadcasting multiplier and
        ``np.prod(self.plates_multiplier)``.
        """
        temperature = 1 / self.annealing

        # Parent side: natural parameters and cgf
        u_parents = self._message_from_parents()
        phi_parents = self._distribution.compute_phi_from_parents(*u_parents)
        L = self._distribution.compute_cgf_from_parents(*u_parents)

        latent = np.logical_not(self.observed)

        # q side: -T*g for latent elements, f for observed ones.  Only the q
        # term is scaled by T; adding (1-T)*f is wrong because the optimal q
        # already weights f by 1/T, so the factors cancel against f of p.
        if np.all(self.observed):
            q_term = np.nan
        elif temperature == 1:
            q_term = -self.g
        else:
            q_term = -temperature * self.g
        L = L + np.where(self.observed, self.f, q_term)

        for (phi_p, phi_q, u_q, dims) in zip(phi_parents, self.phi, self.u,
                                             self.dims):
            n_var = len(dims)
            # Latent mask aligned with the plates and broadcast over the
            # variable axes; observed elements get phi_q = 0.
            latent_i = misc.add_trailing_axes(
                misc.add_leading_axes(latent,
                                      len(self.plates) - np.ndim(latent)),
                n_var)
            phi_q = np.where(latent_i, phi_q, 0)
            delta = phi_p - temperature * phi_q
            # Zero moments contribute nothing (avoid 0 * -inf = nan)
            delta = np.where(u_q != 0, delta, 0)
            L = L + np.sum(delta * u_q, axis=tuple(range(-n_var, 0)))

        if not ignore_masked:
            total = np.sum(L)
            factor = self.broadcasting_multiplier(self.plates, np.shape(L))
        else:
            total = np.sum(np.where(self.mask, L, 0))
            factor = self.broadcasting_multiplier(self.plates, np.shape(L),
                                                  np.shape(self.mask))
        return total * factor * np.prod(self.plates_multiplier)
Exemplo n.º 14
0
    def lower_bound_contribution(self, gradient=False, ignore_masked=True):
        r"""Compute E[ log p(X|parents) - log q(X) ]

        If deterministic annealing is used, the term E[ -log q(X) ] is
        divided by the annealing coefficient.  That is, phi and cgf of q
        are multiplied by the temperature (inverse annealing
        coefficient).

        Elements whose moment ``u_q`` is exactly zero contribute nothing to
        the ``(phi_p - T*phi_q) * u_q`` term, which avoids
        ``0 * (-inf) = nan`` when a natural parameter is infinite.

        Parameters
        ----------
        gradient : bool
            Not used.
        ignore_masked : bool
            If true, only plate elements where ``self.mask`` is true are
            summed; otherwise all elements are summed.
        """

        # Annealing temperature
        T = 1 / self.annealing

        # Messages from parents
        u_parents = self._message_from_parents()
        phi = self._distribution.compute_phi_from_parents(*u_parents)
        # G from parents
        L = self._distribution.compute_cgf_from_parents(*u_parents)

        # G for unobserved variables (ignored variables are handled properly
        # automatically)
        latent_mask = np.logical_not(self.observed)

        # G and F
        if np.all(self.observed):
            z = np.nan
        elif T == 1:
            z = -self.g
        else:
            # Only the q term is scaled by T.  Using -T*g + (1-T)*f is wrong:
            # the optimal q distribution has f weighted by 1/T and here the f
            # of q is weighted by T, so the total weight is 1 and it cancels
            # out with f of p.
            z = -T * self.g

        L = L + np.where(self.observed, self.f, z)

        for (phi_p, phi_q, u_q, dims) in zip(phi, self.phi, self.u, self.dims):
            # Form a mask which puts observed variables to zero and
            # broadcasts properly
            latent_mask_i = misc.add_trailing_axes(
                misc.add_leading_axes(
                    latent_mask,
                    len(self.plates) - np.ndim(latent_mask)),
                len(dims))
            axis_sum = tuple(range(-len(dims), 0))

            # Compute the term
            phi_q = np.where(latent_mask_i, phi_q, 0)
            # Apply annealing
            phi_diff = phi_p - T * phi_q
            # Handle 0 * -inf: zero moments contribute nothing
            phi_diff = np.where(u_q != 0, phi_diff, 0)
            Z = np.sum(phi_diff * u_q, axis=axis_sum)

            L = L + Z

        if ignore_masked:
            return (np.sum(np.where(self.mask, L, 0))
                    * self.broadcasting_multiplier(self.plates,
                                                   np.shape(L),
                                                   np.shape(self.mask))
                    * np.prod(self.plates_multiplier))
        else:
            return (np.sum(L)
                    * self.broadcasting_multiplier(self.plates,
                                                   np.shape(L))
                    * np.prod(self.plates_multiplier))
Exemplo n.º 15
0
    def _message_to_parent(self, index):
        """
        Compute the message from this node to parent ``index``.

        The raw message and mask come from
        ``_get_message_and_mask_to_parent``.  The mask is applied and plate
        axes are summed to the parent's shape via
        ``misc.sum_multiply_to_plates``.  The result is scaled by the ratio
        of the plate multipliers of this node and the parent.

        If the raw message is callable (a log-pdf function, for black-box
        variational inference), a wrapper function is returned instead.  The
        wrapper calls it and applies the same masking, scaling and plate
        summation to the returned log-pdf values, which only have plate axes.

        Raises
        ------
        ValueError
            If ``index`` is not smaller than the number of parents, or if the
            plate multipliers of this node and the parent are incompatible.
        """
        # Compute the message, check plates, apply mask and sum over some plates
        if index >= len(self.parents):
            raise ValueError("Parent index larger than the number of parents")

        # Compute the message and mask
        (m, mask) = self._get_message_and_mask_to_parent(index)
        mask = misc.squeeze(mask)

        # Plates in the mask
        plates_mask = np.shape(mask)

        # The parent we're sending the message to
        parent = self.parents[index]

        # Plates with respect to the parent
        plates_self = self._plates_to_parent(index)

        # Plate multiplier of the parent
        multiplier_parent = self._plates_multiplier_from_parent(index)

        # Check if m is a logpdf function (for black-box variational inference)
        if callable(m):

            def m_function(*args):
                lpdf = m(*args)
                # Log pdf only contains plate axes, so the plate-shaped mask
                # is used as is (no variable axes to add).
                plates_m = np.shape(lpdf)
                r = (self.broadcasting_multiplier(plates_self, plates_m,
                                                  plates_mask, parent.plates) *
                     self.broadcasting_multiplier(self.plates_multiplier,
                                                  multiplier_parent))
                axes_msg = misc.axes_to_collapse(plates_m, parent.plates)
                msg = misc.sum_multiply(mask,
                                        lpdf,
                                        r,
                                        axis=axes_msg,
                                        keepdims=True)

                # Remove leading singular plates if the parent does not have
                # those plate axes.
                return misc.squeeze_to_dim(msg, len(parent.plates))

            return m_function

        # Compact the message to a proper shape
        for i in range(len(m)):

            # Empty messages are given as None. We can ignore those.
            if m[i] is None:
                continue

            try:
                r = self.broadcasting_multiplier(self.plates_multiplier,
                                                 multiplier_parent)
            except Exception as err:
                raise ValueError("The plate multipliers are incompatible. "
                                 "This node (%s) has %s and parent[%d] "
                                 "(%s) has %s" %
                                 (self.name, self.plates_multiplier, index,
                                  parent.name, multiplier_parent)) from err

            ndim = len(parent.dims[i])
            # Source and target shapes
            if ndim > 0:
                dims = misc.broadcasted_shape(
                    np.shape(m[i])[-ndim:], parent.dims[i])
                from_shape = plates_self + dims
            else:
                from_shape = plates_self
            to_shape = parent.get_shape(i)
            # Add variable axes to the mask
            mask_i = misc.add_trailing_axes(mask, ndim)
            # Apply mask and sum plate axes as necessary (and apply plate
            # multiplier)
            m[i] = r * misc.sum_multiply_to_plates(m[i],
                                                   mask_i,
                                                   to_plates=to_shape,
                                                   from_plates=from_shape,
                                                   ndim=0)

        return m
Exemplo n.º 16
0
    def compute_message_to_parent(self, parent, index, u, *u_parents):
        """
        Compute the message to a parent node.

        Parameters
        ----------
        parent : node
            The parent the message is sent to.  It is only forwarded to
            ``self.distribution.compute_message_to_parent`` when
            ``index >= 1``.
        index : int
            Parent index.  0 is the cluster-assignment parent, whose
            probabilities are ``u_parents[0][0]``.  ``index >= 1`` refers to
            parent ``index - 1`` of the mixed distribution
            ``self.distribution``.
        u : list of arrays
            Moments of this node.
        *u_parents
            Moments of all parents, the assignment parent first.

        Returns
        -------
        list of arrays
            For ``index == 0``: ``[L]``, where ``L`` is the log-pdf of ``u``
            under each cluster with the cluster axis last.  For
            ``index >= 1``: the mixed distribution's message with each moment
            weighted by the (plate-mapped) assignment probabilities.
            NOTE(review): for a negative ``index`` neither branch runs and
            ``None`` is returned implicitly.

        Raises
        ------
        RuntimeError
            In the ``index == 0`` branch, if ``self.cluster_plate`` is not
            negative.
        ValueError
            In the ``index >= 1`` branch, if ``self.cluster_plate`` is not
            negative.
        """

        if index == 0:

            # Shape(phi)    = [Nn,..,K,..,N0,Dd,..,D0]
            # Shape(L)      = [Nn,..,K,..,N0]
            # Shape(u)      = [Nn,..,N0,Dd,..,D0]
            # Shape(result) = [Nn,..,N0,K]

            # Compute g:
            # Shape(g)      = [Nn,..,K,..,N0]
            g = self.distribution.compute_cgf_from_parents(*(u_parents[1:]))
            # Reshape(g):
            # Shape(g)      = [Nn,..,N0,K]
            if np.ndim(g) < abs(self.cluster_plate):
                # Not enough axes, just add the cluster plate axis
                g = np.expand_dims(g, -1)
            else:
                # Move the cluster plate axis
                g = misc.moveaxis(g, self.cluster_plate, -1)

            # Compute phi:
            # Shape(phi)    = [Nn,..,K,..,N0,Dd,..,D0]
            phi = self.distribution.compute_phi_from_parents(*(u_parents[1:]))
            # Move phi axis:
            # Shape(phi)    = [Nn,..,N0,K,Dd,..,D0]
            # NOTE(review): the elements of the list returned by
            # compute_phi_from_parents are replaced in place here; confirm the
            # distribution never hands out a cached list.
            for ind in range(len(phi)):
                if self.cluster_plate < 0:
                    axis_from = self.cluster_plate-self.ndims[ind]
                else:
                    raise RuntimeError("Cluster plate axis must be negative")
                axis_to = -1-self.ndims[ind]
                if np.ndim(phi[ind]) >= abs(axis_from):
                    # Cluster plate axis exists, move it to the correct position
                    phi[ind] = misc.moveaxis(phi[ind], axis_from, axis_to)
                else:
                    # No cluster plate axis, just add a new axis to the correct
                    # position, if phi has something on that axis
                    if np.ndim(phi[ind]) >= abs(axis_to):
                        phi[ind] = np.expand_dims(phi[ind], axis=axis_to)

            # Reshape u:
            # Shape(u)      = [Nn,..,N0,1,Dd,..,D0]
            # (a singleton cluster axis so u broadcasts against every cluster)
            u_self = list()
            for ind in range(len(u)):
                u_self.append(np.expand_dims(u[ind],
                                             axis=(-1-self.ndims[ind])))

            # Compute logpdf:
            # Shape(L)      = [Nn,..,N0,K]
            L = self.distribution.compute_logpdf(u_self, phi, g, 0, self.ndims)

            # Sum over other than the cluster dimensions? No!
            # Hmm.. I think the message passing method will do
            # that automatically

            m = [L]

            return m

        elif index >= 1:

            # Parent index for the distribution used for the
            # mixture.
            index_for_parent = index - 1

            # Reshape u:
            # Shape(u)      = [Nn,..1,..,N0,Dd,..,D0]
            # (a singleton axis inserted at the cluster plate position)
            u_self = list()
            for ind in range(len(u)):
                if self.cluster_plate < 0:
                    cluster_axis = self.cluster_plate - self.ndims[ind]
                else:
                    raise ValueError("Cluster plate axis must be negative")
                u_self.append(np.expand_dims(u[ind], axis=cluster_axis))

            # Message from the mixed distribution
            m = self.distribution.compute_message_to_parent(parent,
                                                            index_for_parent,
                                                            u_self,
                                                            *(u_parents[1:]))

            # Note: The cluster assignment probabilities can be considered as
            # weights to plate elements. These weights need to be mapped
            # properly via the plate mapping of self.distribution. Otherwise,
            # nested mixtures won't work, or possibly not any distribution
            # that does something to the plates. Thus, use
            # compute_weights_to_parent to compute the transformations to the
            # weight array properly.
            #
            # See issue #39 for more details.

            # Compute weights (i.e., cluster assignment probabilities) and map
            # the plates properly.  The cluster axis is last in u_parents[0][0]
            # and is moved to the cluster plate position.
            p = misc.atleast_nd(u_parents[0][0], abs(self.cluster_plate))
            p = misc.moveaxis(p, -1, self.cluster_plate)
            p = self.distribution.compute_weights_to_parent(
                index_for_parent,
                p,
            )

            # Weigh the elements in the message array.  The weights get the
            # variable axes of the *parent's* moments, not of this node's.
            m = [mi * misc.add_trailing_axes(p, ndim)
                 #for (mi, ndim) in zip(m, self.ndims)]
                 for (mi, ndim) in zip(m, self.ndims_parents[index_for_parent])]

            return m
Exemplo n.º 17
0
    def compute_phi_from_parents(self, *u_parents, mask=True):
        """
        Compute the natural parameter vector given parent moments.

        The natural parameters of the mixed distribution are computed from
        ``u_parents[1:]``.  They are then averaged over the cluster axis,
        with the assignment probabilities ``u_parents[0][0]`` (cluster axis
        last) as weights:

            Shape(phi)    = [Nn,..,K,..,N0,Dd,..,D0]
            Shape(p)      = [Nn,..,N0,K]
            Shape(result) = [Nn,..,N0,Dd,..,D0]

        Standard broadcasting applies to the plate axes Nn,..,N0.  Clusters
        with zero probability are zeroed out before the weighted sum, so an
        infinite parameter does not yield 0*inf=nan.  A warning is issued if
        the result still contains nans.  ``mask`` is not used.

        Raises RuntimeError if ``self.cluster_plate`` is not negative (only
        checked when there is at least one parameter array).
        """
        cluster_phi = self.distribution.compute_phi_from_parents(*(u_parents[1:]))
        probs = u_parents[0][0]

        averaged = []
        saw_nan = False

        for (ind, phi_k) in enumerate(cluster_phi):
            if not self.cluster_plate < 0:
                raise RuntimeError("Cluster plate should be negative")
            n_var = self.ndims[ind]
            cluster_axis = self.cluster_plate - n_var

            # Cluster axis of phi to the end: [Nn,..,N0,Dd,..,D0,K]
            if np.ndim(phi_k) < abs(cluster_axis):
                phi_k = phi_k[..., None]
            else:
                phi_k = misc.moveaxis(phi_k, cluster_axis, -1)

            # Probabilities: [Nn,..,N0,K] -> [Nn,..,N0,1,..,1,K]
            weights = misc.moveaxis(misc.add_trailing_axes(probs, n_var),
                                    -(n_var + 1), -1)

            # Zero-probability clusters contribute nothing
            phi_k = np.where(weights != 0, phi_k, 0)

            combined = misc.sum_product(weights, phi_k, axes_to_sum=-1)
            if np.any(np.isnan(combined)):
                saw_nan = True
            averaged.append(combined)

        if saw_nan:
            warnings.warn("The natural parameters of mixture distribution "
                          "contain nans. This may happen if you use fixed "
                          "parameters in your model. Technically, one possible "
                          "reason is that the cluster assignment probability "
                          "for some element is zero (p=0) and the natural "
                          "parameter of that cluster is -inf, thus "
                          "0*(-inf)=nan. Solution: Use parameters that assign "
                          "non-zero probabilities for the whole domain.")

        return averaged
Exemplo n.º 18
0
    def compute_message_to_parent(self, parent, index, u, *u_parents):
        """
        Compute the message from this mixture node to a parent node.

        Parent 0 provides the cluster assignment probabilities; parents 1, 2,
        ... are the parents of the mixed distribution ``self.distribution``.

        Parameters
        ----------
        parent :
            The parent node receiving the message. Only forwarded to
            ``self.distribution`` when ``index >= 1``.
        index : int
            Index of the parent. 0 is the cluster assignment parent.
        u : list of arrays
            Moments of this node.
        *u_parents :
            Moments of the parents. ``u_parents[0][0]`` holds the assignment
            probabilities (cluster axis last) and ``u_parents[1:]`` the
            moments of the mixed distribution's parents.

        Returns
        -------
        list or None
            For ``index == 0``: ``[L]`` with ``L`` of shape [Nn,..,N0,K].
            For ``index >= 1``: the mixed distribution's message with each
            element weighted by the plate-mapped assignment probabilities.
            ``None`` for a negative index (no branch matches).

        Raises
        ------
        RuntimeError
            If ``index == 0`` and ``self.cluster_plate`` is not negative.
        ValueError
            If ``index >= 1`` and ``self.cluster_plate`` is not negative.
        """
        mixed_parents = u_parents[1:]

        if index == 0:
            # Everything is laid out with the cluster plate K as the LAST
            # plate axis before the logpdf is evaluated:
            #   g      [Nn,..,K,..,N0]          -> [Nn,..,N0,K]
            #   phi    [Nn,..,K,..,N0,Dd,..,D0] -> [Nn,..,N0,K,Dd,..,D0]
            #   u      [Nn,..,N0,Dd,..,D0]      -> [Nn,..,N0,1,Dd,..,D0]
            #   result [Nn,..,N0,K]
            cgf = self.distribution.compute_cgf_from_parents(*mixed_parents)
            if np.ndim(cgf) >= abs(self.cluster_plate):
                cgf = misc.moveaxis(cgf, self.cluster_plate, -1)
            else:
                # Too few axes to contain the cluster plate: append one
                cgf = np.expand_dims(cgf, -1)

            params = self.distribution.compute_phi_from_parents(*mixed_parents)
            for k in range(len(params)):
                if self.cluster_plate >= 0:
                    raise RuntimeError("Cluster plate axis must be negative")
                var_ndim = self.ndims[k]
                src = self.cluster_plate - var_ndim
                dst = -1 - var_ndim
                k_ndim = np.ndim(params[k])
                if k_ndim >= abs(src):
                    # The cluster axis exists: move it just before the
                    # variable axes
                    params[k] = misc.moveaxis(params[k], src, dst)
                elif k_ndim >= abs(dst):
                    # No cluster axis, but the array reaches the target
                    # position: insert a unit cluster axis there
                    params[k] = np.expand_dims(params[k], axis=dst)

            # Unit cluster axis just before the variable axes of u
            moments = [
                np.expand_dims(u[k], axis=(-1 - self.ndims[k]))
                for k in range(len(u))
            ]

            # Plate axes other than the cluster axis are deliberately not
            # summed here (the original author expected the message passing
            # machinery to take care of that).
            logpdf = self.distribution.compute_logpdf(
                moments, params, cgf, 0, self.ndims)
            return [logpdf]

        if index >= 1:
            # Index of the target among the mixed distribution's parents
            inner_index = index - 1

            # Insert a unit cluster axis into own moments:
            # [Nn,..,N0,Dd,..,D0] -> [Nn,..,1,..,N0,Dd,..,D0]
            moments = []
            for k in range(len(u)):
                if self.cluster_plate >= 0:
                    raise ValueError("Cluster plate axis must be negative")
                moments.append(
                    np.expand_dims(u[k],
                                   axis=self.cluster_plate - self.ndims[k]))

            msg = self.distribution.compute_message_to_parent(
                parent, inner_index, moments, *mixed_parents)

            # The assignment probabilities act as weights on plate elements
            # and must be mapped through the plate mapping of
            # self.distribution; otherwise e.g. nested mixtures break (see
            # issue #39).
            weights = misc.atleast_nd(u_parents[0][0], abs(self.cluster_plate))
            weights = misc.moveaxis(weights, -1, self.cluster_plate)
            weights = self.distribution.compute_weights_to_parent(
                inner_index,
                weights,
            )

            return [
                msg_k * misc.add_trailing_axes(weights, var_ndim)
                for (msg_k, var_ndim) in zip(
                    msg, self.ndims_parents[inner_index])
            ]

        return None
Exemplo n.º 19
0
    def compute_phi_from_parents(self, *u_parents, mask=True):
        """
        Compute the natural parameter vector given parent moments.

        The natural parameters of the clusters are averaged using the cluster
        assignment probabilities as weights.

        Parameters
        ----------
        *u_parents :
            Parent moments. ``u_parents[0][0]`` holds the assignment
            probabilities, shape [Nn,..,N0,K]; ``u_parents[1:]`` are passed to
            ``self.distribution.compute_phi_from_parents``.
        mask :
            Not used by this method.

        Returns
        -------
        list
            Weighted natural parameters, each of shape [Nn,..,N0,Dd,..,D0].

        Raises
        ------
        RuntimeError
            If ``self.cluster_plate`` is not negative.

        Warns
        -----
        UserWarning
            If any of the resulting natural parameters contains nans.
        """
        cluster_params = self.distribution.compute_phi_from_parents(
            *(u_parents[1:]))
        probabilities = u_parents[0][0]

        result = []
        saw_nan = False

        for k in range(len(cluster_params)):
            # Shape(params)  = [Nn,..,K,..,N0,Dd,..,D0]
            # Shape(probs)   = [Nn,..,N0,K]
            # Shape(average) = [Nn,..,N0,Dd,..,D0]
            # Leading plate axes follow ordinary broadcasting rules, so they
            # may be missing or have length one.
            if self.cluster_plate >= 0:
                raise RuntimeError("Cluster plate should be negative")
            var_ndim = self.ndims[k]
            source_axis = self.cluster_plate - var_ndim

            # Cluster axis of the parameters to the end:
            # [Nn,..,N0,Dd,..,D0,K]
            params_k = cluster_params[k]
            if np.ndim(params_k) < abs(source_axis):
                params_k = params_k[..., None]
            else:
                params_k = misc.moveaxis(params_k, source_axis, -1)

            # Probabilities to [Nn,..,N0,1,..,1,K] so that they broadcast
            # against the parameters
            weights = misc.add_trailing_axes(probabilities, var_ndim)
            weights = misc.moveaxis(weights, -(var_ndim + 1), -1)

            # Zero-probability clusters contribute exactly zero; this avoids
            # nan from 0*inf.
            params_k = np.where(weights != 0, params_k, 0)

            # Weighted sum over the (last) cluster axis
            average = misc.sum_product(weights, params_k, axes_to_sum=-1)
            if np.any(np.isnan(average)):
                saw_nan = True
            result.append(average)

        if saw_nan:
            warnings.warn(
                "The natural parameters of mixture distribution "
                "contain nans. This may happen if you use fixed "
                "parameters in your model. Technically, one possible "
                "reason is that the cluster assignment probability "
                "for some element is zero (p=0) and the natural "
                "parameter of that cluster is -inf, thus "
                "0*(-inf)=nan. Solution: Use parameters that assign "
                "non-zero probabilities for the whole domain.")

        return result
Exemplo n.º 20
0
    def compute_message_to_parent(self, parent, index, u, *u_parents):
        """
        Compute the message to a parent node.

        Parent 0 provides the cluster assignment probabilities; parents 1, 2,
        ... are the parents of the mixed distribution ``self.distribution``.

        Parameters
        ----------
        parent :
            The parent node receiving the message. Only forwarded to
            ``self.distribution`` when ``index >= 1``.
        index : int
            Index of the parent. 0 is the cluster assignment parent.
        u : list of arrays
            Moments of this node.
        *u_parents :
            Moments of the parents. ``u_parents[0][0]`` holds the assignment
            probabilities (cluster axis last) and ``u_parents[1:]`` the
            moments of the mixed distribution's parents.

        Returns
        -------
        list or None
            For ``index == 0``: ``[L]`` with ``L`` of shape [Nn,..,N0,K].
            For ``index >= 1``: the mixed distribution's message (the same
            list object, updated in place) with each element multiplied by
            the assignment probabilities. ``None`` for a negative index.

        Raises
        ------
        RuntimeError
            If ``index == 0`` and ``self.cluster_plate`` is not negative.
        """

        if index == 0:

            # Shape(phi)    = [Nn,..,K,..,N0,Dd,..,D0]
            # Shape(L)      = [Nn,..,K,..,N0]
            # Shape(u)      = [Nn,..,N0,Dd,..,D0]
            # Shape(result) = [Nn,..,N0,K]

            # Compute g:
            # Shape(g)      = [Nn,..,K,..,N0]
            g = self.distribution.compute_cgf_from_parents(*(u_parents[1:]))
            # Reshape(g):
            # Shape(g)      = [Nn,..,N0,K]
            if np.ndim(g) < abs(self.cluster_plate):
                # Not enough axes, just add the cluster plate axis
                g = np.expand_dims(g, -1)
            else:
                # Move the cluster plate axis
                g = misc.moveaxis(g, self.cluster_plate, -1)

            # Compute phi:
            # Shape(phi)    = [Nn,..,K,..,N0,Dd,..,D0]
            phi = self.distribution.compute_phi_from_parents(*(u_parents[1:]))
            # Move phi axis:
            # Shape(phi)    = [Nn,..,N0,K,Dd,..,D0]
            for ind in range(len(phi)):
                if self.cluster_plate < 0:
                    axis_from = self.cluster_plate - self.ndims[ind]
                else:
                    raise RuntimeError("Cluster plate axis must be negative")
                axis_to = -1 - self.ndims[ind]
                if np.ndim(phi[ind]) >= abs(axis_from):
                    # Cluster plate axis exists, move it to the correct position
                    phi[ind] = misc.moveaxis(phi[ind], axis_from, axis_to)
                elif np.ndim(phi[ind]) >= abs(axis_to):
                    # No cluster plate axis, just add a new axis to the
                    # correct position, if phi has something on that axis
                    phi[ind] = np.expand_dims(phi[ind], axis=axis_to)

            # Reshape u:
            # Shape(u)      = [Nn,..,N0,1,Dd,..,D0]
            u_self = [
                np.expand_dims(u[ind], axis=(-1 - self.ndims[ind]))
                for ind in range(len(u))
            ]

            # Compute logpdf:
            # Shape(L)      = [Nn,..,N0,K]
            L = self.distribution.compute_logpdf(u_self, phi, g, 0, self.ndims)

            # Plate axes other than the cluster axis are not summed here; the
            # original author expected the message passing machinery to do
            # that.
            return [L]

        elif index >= 1:

            # Parent index for the distribution used for the mixture (kept
            # separate from ``index`` instead of rebinding the parameter).
            index_for_parent = index - 1

            # Reshape u:
            # Shape(u)      = [Nn,..1,..,N0,Dd,..,D0]
            u_self = list()
            for ind in range(len(u)):
                if self.cluster_plate < 0:
                    cluster_axis = self.cluster_plate - self.ndims[ind]
                else:
                    cluster_axis = self.cluster_plate
                u_self.append(np.expand_dims(u[ind], axis=cluster_axis))

            # Message from the mixed distribution
            # Shape(m)      = [Nn,..,K,..,N0,Dd,..,D0]
            m = self.distribution.compute_message_to_parent(
                parent, index_for_parent, u_self, *(u_parents[1:]))

            # Responsibilities for clusters are the first parent's first
            # moment. Moving the cluster axis does not depend on the message
            # element, so it is done once here instead of once per element.
            # Shape(p)      = [Nn,..,N0,K] -> [Nn,..,K,..,N0]
            p = misc.atleast_nd(u_parents[0][0], abs(self.cluster_plate))
            p = misc.moveaxis(p, -1, self.cluster_plate)

            # Weigh the messages with the responsibilities
            for i in range(len(m)):
                # Number of axes for the variable dimensions of the parent
                # message; add that many unit axes to the weights:
                # Shape(p_i)    = [Nn,..,K,..,N0,1,..,1]
                D = self.ndims_parents[index_for_parent][i]
                p_i = misc.add_trailing_axes(p, D)

                # TODO: You could do summing here already so that you
                # wouldn't compute huge matrices as intermediate result. Use
                # einsum.

                # Shape(result) = [Nn,..,K,..,N0,Dd,..,D0]
                m[i] = m[i] * p_i

            return m
Exemplo n.º 21
0
    def compute_message_to_parent(self, parent, index, u, *u_parents):
        """
        Compute the message to a parent node.

        Parent 0 provides the cluster assignment probabilities; parents 1, 2,
        ... are the parents of the mixed distribution ``self.distribution``.

        Parameters
        ----------
        parent :
            The parent node receiving the message. Only forwarded to
            ``self.distribution`` when ``index >= 1``.
        index : int
            Index of the parent. 0 is the cluster assignment parent.
        u : list of arrays
            Moments of this node.
        *u_parents :
            Moments of the parents. ``u_parents[0][0]`` holds the assignment
            probabilities (cluster axis last) and ``u_parents[1:]`` the
            moments of the mixed distribution's parents.

        Returns
        -------
        list or None
            For ``index == 0``: ``[L]`` with ``L`` of shape [Nn,..,N0,K].
            For ``index >= 1``: the mixed distribution's message (the same
            list object, updated in place) with each element multiplied by
            the assignment probabilities. ``None`` for a negative index.

        Raises
        ------
        RuntimeError
            If ``index == 0`` and ``self.cluster_plate`` is not negative.
        """

        if index == 0:

            # Shape(phi)    = [Nn,..,K,..,N0,Dd,..,D0]
            # Shape(L)      = [Nn,..,K,..,N0]
            # Shape(u)      = [Nn,..,N0,Dd,..,D0]
            # Shape(result) = [Nn,..,N0,K]

            # Compute g:
            # Shape(g)      = [Nn,..,K,..,N0]
            g = self.distribution.compute_cgf_from_parents(*(u_parents[1:]))
            # Reshape(g):
            # Shape(g)      = [Nn,..,N0,K]
            if np.ndim(g) < abs(self.cluster_plate):
                # Not enough axes, just add the cluster plate axis
                g = np.expand_dims(g, -1)
            else:
                # Move the cluster plate axis
                g = misc.moveaxis(g, self.cluster_plate, -1)

            # Compute phi:
            # Shape(phi)    = [Nn,..,K,..,N0,Dd,..,D0]
            phi = self.distribution.compute_phi_from_parents(*(u_parents[1:]))
            # Move phi axis:
            # Shape(phi)    = [Nn,..,N0,K,Dd,..,D0]
            for ind in range(len(phi)):
                if self.cluster_plate < 0:
                    axis_from = self.cluster_plate - self.ndims[ind]
                else:
                    raise RuntimeError("Cluster plate axis must be negative")
                axis_to = -1 - self.ndims[ind]
                if np.ndim(phi[ind]) >= abs(axis_from):
                    # Cluster plate axis exists, move it to the correct position
                    phi[ind] = misc.moveaxis(phi[ind], axis_from, axis_to)
                elif np.ndim(phi[ind]) >= abs(axis_to):
                    # No cluster plate axis, just add a new axis to the
                    # correct position, if phi has something on that axis
                    phi[ind] = np.expand_dims(phi[ind], axis=axis_to)

            # Reshape u:
            # Shape(u)      = [Nn,..,N0,1,Dd,..,D0]
            u_self = [
                np.expand_dims(u[ind], axis=(-1 - self.ndims[ind]))
                for ind in range(len(u))
            ]

            # Compute logpdf:
            # Shape(L)      = [Nn,..,N0,K]
            L = self.distribution.compute_logpdf(u_self, phi, g, 0, self.ndims)

            # Plate axes other than the cluster axis are not summed here; the
            # original author expected the message passing machinery to do
            # that.
            return [L]

        elif index >= 1:

            # Parent index for the distribution used for the mixture (kept
            # separate from ``index`` instead of rebinding the parameter).
            index_for_parent = index - 1

            # Reshape u:
            # Shape(u)      = [Nn,..1,..,N0,Dd,..,D0]
            u_self = list()
            for ind in range(len(u)):
                if self.cluster_plate < 0:
                    cluster_axis = self.cluster_plate - self.ndims[ind]
                else:
                    cluster_axis = self.cluster_plate
                u_self.append(np.expand_dims(u[ind], axis=cluster_axis))

            # Message from the mixed distribution
            # Shape(m)      = [Nn,..,K,..,N0,Dd,..,D0]
            m = self.distribution.compute_message_to_parent(
                parent, index_for_parent, u_self, *(u_parents[1:]))

            # Responsibilities for clusters are the first parent's first
            # moment. Moving the cluster axis does not depend on the message
            # element, so it is done once here instead of once per element.
            # Shape(p)      = [Nn,..,N0,K] -> [Nn,..,K,..,N0]
            p = misc.atleast_nd(u_parents[0][0], abs(self.cluster_plate))
            p = misc.moveaxis(p, -1, self.cluster_plate)

            # Weigh the messages with the responsibilities
            for i in range(len(m)):
                # Number of axes for the variable dimensions of the parent
                # message; add that many unit axes to the weights:
                # Shape(p_i)    = [Nn,..,K,..,N0,1,..,1]
                D = self.ndims_parents[index_for_parent][i]
                p_i = misc.add_trailing_axes(p, D)

                # TODO: You could do summing here already so that you
                # wouldn't compute huge matrices as intermediate result. Use
                # einsum.

                # Shape(result) = [Nn,..,K,..,N0,Dd,..,D0]
                m[i] = m[i] * p_i

            return m
Exemplo n.º 22
0
    def _message_to_parent(self, index):
        """
        Compute the message to a parent and reduce it to the parent's shape.

        The raw message from ``_get_message_and_mask_to_parent`` is masked,
        summed over plate axes as necessary and scaled by the plate
        multipliers.

        Parameters
        ----------
        index : int
            Index of the parent in ``self.parents``.

        Returns
        -------
        list or callable
            A list of message arrays (``None`` entries are passed through
            unchanged). If the raw message is a log pdf function (black-box
            variational inference), a function is returned that evaluates it
            and reduces the result to the parent's plates.

        Raises
        ------
        ValueError
            If ``index`` is not smaller than the number of parents, or if the
            plate multipliers of this node and the parent are incompatible.
        """

        # Compute the message, check plates, apply mask and sum over some plates
        if index >= len(self.parents):
            raise ValueError("Parent index larger than the number of parents")

        # Compute the message and mask
        (m, mask) = self._get_message_and_mask_to_parent(index)
        mask = misc.squeeze(mask)

        # Plates in the mask
        plates_mask = np.shape(mask)

        # The parent we're sending the message to
        parent = self.parents[index]

        # Plates with respect to the parent
        plates_self = self._plates_to_parent(index)

        # Plate multiplier of the parent
        multiplier_parent = self._plates_multiplier_from_parent(index)

        # Check if m is a logpdf function (for black-box variational inference)
        if callable(m):
            def m_function(*args):
                lpdf = m(*args)
                # Log pdf only contains plate axes, so the squeezed mask can
                # be applied without adding variable axes.
                plates_m = np.shape(lpdf)
                r = (self.broadcasting_multiplier(plates_self,
                                                  plates_m,
                                                  plates_mask,
                                                  parent.plates) *
                     self.broadcasting_multiplier(self.plates_multiplier,
                                                  multiplier_parent))
                axes_msg = misc.axes_to_collapse(plates_m, parent.plates)
                lpdf = misc.sum_multiply(mask, lpdf, r,
                                         axis=axes_msg,
                                         keepdims=True)

                # Remove leading singular plates if the parent does not have
                # those plate axes.
                return misc.squeeze_to_dim(lpdf, len(parent.plates))

            return m_function

        # The plate multiplier does not depend on the message element, so it
        # is computed once. It is only needed (and only validated) when there
        # is at least one non-empty message.
        if any(mi is not None for mi in m):
            try:
                r = self.broadcasting_multiplier(self.plates_multiplier,
                                                 multiplier_parent)
            except Exception as error:
                raise ValueError("The plate multipliers are incompatible. "
                                 "This node (%s) has %s and parent[%d] "
                                 "(%s) has %s"
                                 % (self.name,
                                    self.plates_multiplier,
                                    index,
                                    parent.name,
                                    multiplier_parent)) from error

        # Compact the message to a proper shape
        for i in range(len(m)):

            # Empty messages are given as None. We can ignore those.
            if m[i] is None:
                continue

            ndim = len(parent.dims[i])
            # Source and target shapes
            if ndim > 0:
                dims = misc.broadcasted_shape(np.shape(m[i])[-ndim:],
                                              parent.dims[i])
                from_shape = plates_self + dims
            else:
                from_shape = plates_self
            to_shape = parent.get_shape(i)
            # Add variable axes to the mask
            mask_i = misc.add_trailing_axes(mask, ndim)
            # Apply mask and sum plate axes as necessary (and apply plate
            # multiplier)
            m[i] = r * misc.sum_multiply_to_plates(m[i], mask_i,
                                                   to_plates=to_shape,
                                                   from_plates=from_shape,
                                                   ndim=0)

        return m
Exemplo n.º 23
0
    def compute_message_to_parent(self, parent, index, u, *u_parents):
        """
        Compute the message to a parent node.

        Parent 0 provides the cluster assignment probabilities; parents 1, 2,
        ... are the parents of the mixed distribution
        ``self.raw_distribution``.

        Parameters
        ----------
        parent :
            The parent node receiving the message. Only forwarded to
            ``self.raw_distribution`` when ``index >= 1``.
        index : int
            Index of the parent. 0 is the cluster assignment parent.
        u : list of arrays
            Moments of this node.
        *u_parents :
            Moments of the parents. ``u_parents[0][0]`` holds the assignment
            probabilities (cluster axis last) and ``u_parents[1:]`` the
            moments of the mixed distribution's parents.

        Returns
        -------
        list or None
            For ``index == 0``: ``[L]`` with ``L`` of shape [Nn,..,N0,K].
            For ``index >= 1``: the mixed distribution's message with each
            element weighted by the plate-mapped assignment probabilities.
            ``None`` for a negative index.

        Raises
        ------
        ValueError
            If ``index >= 1`` and ``self.cluster_plate`` is not negative.
        """

        if index == 0:

            # Shape(phi)    = [Nn,..,K,..,N0,Dd,..,D0]
            # Shape(g)      = [Nn,..,K,..,N0]
            # Shape(u)      = [Nn,..,N0,Dd,..,D0]
            # Shape(L)      = [Nn,..,K,..,N0]
            # Shape(result) = [Nn,..,N0,K]
            #
            # Here phi and (reshaped) u keep the cluster plate K in its
            # original position and L is computed in that layout; the
            # cluster axis is moved to the end only afterwards. Therefore g
            # must stay in the same [Nn,..,K,..,N0] layout. Moving only g's
            # cluster axis to the end would misalign it with the phi/u terms
            # whenever cluster_plate != -1.
            g = self.raw_distribution.compute_cgf_from_parents(
                *(u_parents[1:]))

            phi = self.raw_distribution.compute_phi_from_parents(
                *(u_parents[1:]))

            # Reshape u:
            # Shape(u)      = [Nn,..,1,..,N0,Dd,..,D0]
            # (only when u reaches the cluster position; otherwise ordinary
            # broadcasting already aligns it)
            u_reshaped = [
                np.expand_dims(ui, self.cluster_plate - ndimi)
                if np.ndim(ui) >= abs(self.cluster_plate - ndimi) else ui
                for (ui, ndimi) in zip(u, self.ndims)
            ]

            # Compute logpdf:
            # Shape(L)      = [Nn,..,K,..,N0]
            L = self.raw_distribution.compute_logpdf(
                u_reshaped,
                phi,
                g,
                0,
                self.ndims,
            )

            # L may have fewer axes than the cluster plate position (when no
            # term had that axis); add leading unit axes so the move works.
            # Shape(L)      = [Nn,..,N0,K]
            L = misc.atleast_nd(L, abs(self.cluster_plate))
            L = np.moveaxis(L, self.cluster_plate, -1)

            return [L]

        elif index >= 1:

            # Parent index for the distribution used for the
            # mixture.
            index_for_parent = index - 1

            # Reshape u:
            # Shape(u_self)  = [Nn,..1,..,N0,Dd,..,D0]
            u_self = list()
            for ind in range(len(u)):
                if self.cluster_plate < 0:
                    cluster_axis = self.cluster_plate - self.ndims[ind]
                else:
                    raise ValueError("Cluster plate axis must be negative")
                u_self.append(np.expand_dims(u[ind], axis=cluster_axis))

            # Message from the mixed distribution
            # Shape(m)       = [Nn,..,K,..,N0,Dd,..,D0]
            m = self.raw_distribution.compute_message_to_parent(
                parent, index_for_parent, u_self, *(u_parents[1:]))

            # The cluster assignment probabilities can be considered as
            # weights to plate elements. These weights need to be mapped
            # properly via the plate mapping of the mixed distribution;
            # otherwise nested mixtures (or any distribution that does
            # something to the plates) won't work. See issue #39.
            # Shape(p)       = [Nn,..,K,..,N0]
            p = misc.atleast_nd(u_parents[0][0], abs(self.cluster_plate))
            p = misc.moveaxis(p, -1, self.cluster_plate)
            p = self.raw_distribution.compute_weights_to_parent(
                index_for_parent,
                p,
            )

            # Weigh the elements in the message array
            #
            # TODO/FIXME: This may result in huge intermediate arrays. Need to
            # use einsum!
            return [
                mi * misc.add_trailing_axes(p, ndim)
                for (mi, ndim) in zip(m, self.ndims_parents[index_for_parent])
            ]