Example #1
0
    def _compute_moments(self, *u_parents):
        """
        Compute the moments of this node from the moments of its parents.

        Each element of `u_parents` is indexed like a sequence.  Element 0 of
        every parent is contracted with `np.einsum` using the keys in
        `self.in_keys` and `self.out_keys`, which gives the first output
        moment.  Element 1 of every parent is contracted using every key
        twice (once shifted by `self.N_keys`), which gives the second output
        moment.  If `self.gaussian_gamma` is true, elements 2 of the parents
        are multiplied together and elements 3 are summed.

        Returns `[x0, x1]`, or `[x0, x1, x2, x3]` when `self.gaussian_gamma`
        is true.
        """

        # The total number of unique keys used (keys are 0,1,...,N_keys-1)
        num_keys = self.N_keys

        def plate_keys(offset, count):
            # Keys for `count` plate axes: offset+count-1, ..., offset+1, offset
            return list(range(offset + count - 1, offset - 1, -1))

        # Number of plate axes in the first moment of each parent
        mean_plate_counts = []
        for (keys, u) in zip(self.in_keys, u_parents):
            mean_plate_counts.append(np.ndim(u[0]) - len(keys))

        # Number of plate axes in the second moment of each parent
        cov_plate_counts = []
        for (keys, u) in zip(self.in_keys, u_parents):
            cov_plate_counts.append(np.ndim(u[1]) - 2 * len(keys))

        # Number of plate axes in the outputs
        mean_out_plates = max(mean_plate_counts)
        cov_out_plates = max(cov_plate_counts)

        #
        # First moment: plate keys start from num_keys
        #
        mean_out_keys = plate_keys(num_keys, mean_out_plates) + self.out_keys
        mean_in_keys = []
        for (count, keys) in zip(mean_plate_counts, self.in_keys):
            mean_in_keys.append(plate_keys(num_keys, count) + keys)
        first_moments = [u[0] for u in u_parents]
        mean_args = (
            misc.zipper_merge(first_moments, mean_in_keys)
            + [mean_out_keys]
        )
        x0 = np.einsum(*mean_args)

        #
        # Second moment: each key k appears as the pair (num_keys+k, k) and
        # plate keys start from 2*num_keys
        #
        cov_out_keys = (
            plate_keys(2 * num_keys, cov_out_plates)
            + [num_keys + key for key in self.out_keys]
            + self.out_keys
        )
        cov_in_keys = []
        for (count, keys) in zip(cov_plate_counts, self.in_keys):
            shifted_keys = [num_keys + key for key in keys]
            cov_in_keys.append(
                plate_keys(2 * num_keys, count) + shifted_keys + keys
            )
        second_moments = [u[1] for u in u_parents]
        cov_args = (
            misc.zipper_merge(second_moments, cov_in_keys)
            + [cov_out_keys]
        )
        x1 = np.einsum(*cov_args)

        if not self.gaussian_gamma:
            return [x0, x1]

        # Gaussian-gamma specific moments: product of the third elements and
        # sum of the fourth elements over all parents
        x2 = 1
        x3 = 0
        for u in u_parents:
            x2 = x2 * u[2]
            x3 = x3 + u[3]

        return [x0, x1, x2, x3]
Example #2
0
def MultiMixture(thetas, *mixture_args, **kwargs):
    """Build a mixture over several axes, one categorical variable per axis.

    The mixings are assumed to be separate, that is, inner mixings don't affect
    the parameters of outer mixings.

    The i-th element of `thetas` is indexed with ``(..., None, ..., None)``
    containing `i` ``None`` entries.  The first indexed theta, then
    ``Mixture`` and the remaining indexed thetas combined by
    `misc.zipper_merge`, and finally `mixture_args` form the positional
    arguments of the returned ``Mixture(...)`` call; `kwargs` are passed
    through unchanged.
    """
    thetas = list(thetas)
    num_mixings = len(thetas)

    # Each mixed axis is assumed separate from the others, so the i-th theta
    # gets i trailing plate axes.
    expanded = []
    for (depth, theta) in enumerate(thetas):
        index = (Ellipsis, ) + depth * (None, )
        expanded.append(theta[index])

    first = expanded[:1]
    nested = list(
        misc.zipper_merge((num_mixings - 1) * [Mixture], expanded[1:])
    )
    args = first + nested + list(mixture_args)
    return Mixture(*args, **kwargs)
Example #3
0
def MultiMixture(thetas, *mixture_args, **kwargs):
    """Create a mixture over several axes using as many categorical variables.

    The mixings are assumed to be separate, that is, inner mixings don't affect
    the parameters of outer mixings.

    Positional arguments of the resulting ``Mixture(...)`` call are: the first
    theta, the result of `misc.zipper_merge` applied to ``Mixture`` repeated
    ``len(thetas) - 1`` times and the remaining thetas, and then
    `mixture_args`.  Keyword arguments are forwarded as given.
    """

    def add_trailing_axes(theta, count):
        # Append `count` new axes after the existing ones
        return theta[(Ellipsis,) + count * (None,)]

    thetas = list(thetas)
    count = len(thetas)

    # Every mixed axis is treated as separate from the others, hence the i-th
    # theta receives i trailing plate axes.
    thetas = [
        add_trailing_axes(theta, position)
        for (position, theta) in enumerate(thetas)
    ]

    head = thetas[:1]
    tail = list(misc.zipper_merge((count - 1) * [Mixture], thetas[1:]))
    return Mixture(*(head + tail + list(mixture_args)), **kwargs)
Example #4
0
    def _compute_function(self, *x_parents):
        """
        Evaluate the sum-product for given parent values.

        When `self.gaussian_gamma` is false, `x_parents` are used directly as
        the einsum operands and only the einsum result is returned.  When it
        is true, each element of `x_parents` is unpacked as a pair: the first
        components are the einsum operands and the second components are
        passed to `misc.multiply`; a tuple of both results is returned.
        """

        # TODO: Add unit tests for this function

        if self.gaussian_gamma:
            # Split the (x, alpha) pairs into two parallel tuples
            (xs, alphas) = zip(*x_parents)
        else:
            xs = x_parents

        # Prefix every key list with Ellipsis so that plates are broadcast
        plated_in_keys = [[Ellipsis] + keys for keys in self.in_keys]
        plated_out_keys = [Ellipsis] + self.out_keys

        einsum_args = misc.zipper_merge(xs, plated_in_keys) + [plated_out_keys]
        y = np.einsum(*einsum_args)

        if self.gaussian_gamma:
            return (y, misc.multiply(*alphas))
        return y
Example #5
0
    def _compute_message(*arrays, plates_from=(), plates_to=(), ndim=0):
        """
        A general function for computing messages by sum-multiply

        The function computes the product of the input arrays and then sums to
        the requested plates.

        Parameters
        ----------
        *arrays
            Arrays to multiply together.  Their shapes are combined with
            `misc.broadcasted_shape`.
        plates_from : tuple
            Plates of the sender.
        plates_to : tuple
            Plates of the receiver.  Must pass
            `misc.is_shape_subset(plates_to, plates_from)`.
        ndim : int
            Number of trailing axes that are variable dimensions rather than
            plates.  These axes are never summed.

        Returns
        -------
        y : ndarray
            The sum-product scaled by the broadcasting correction from
            `Node.broadcasting_multiplier`.  It keeps at most `len(plates_to)`
            plate axes, followed by the `ndim` dimension axes.

        Raises
        ------
        ValueError
            If `plates_to` is not a shape subset of `plates_from`.
        """

        # Check that the plates broadcast properly
        if not misc.is_shape_subset(plates_to, plates_from):
            raise ValueError("plates_to must be broadcastable to plates_from")

        # Compute the explicit shape of the product
        shapes = [np.shape(array) for array in arrays]
        arrays_shape = misc.broadcasted_shape(*shapes)

        # Compute plates and dims that are present
        if ndim == 0:
            arrays_plates = arrays_shape
            dims = ()
        else:
            arrays_plates = arrays_shape[:-ndim]
            dims = arrays_shape[-ndim:]

        # Compute the correction term.  If some of the plates that should be
        # summed are actually broadcasted, one must multiply by the size of the
        # corresponding plate
        r = Node.broadcasting_multiplier(plates_from, arrays_plates, plates_to)

        # For simplicity, make the arrays equal ndim
        arrays = misc.make_equal_ndim(*arrays)

        # Keys for the input plates: (N-1, N-2, ..., 0).  The keys must be in
        # descending order so that key k labels the k-th plate axis counted
        # from the end.  The output filter below relies on this when it looks
        # up plates_to[-key-1].
        nplates = len(arrays_plates)
        in_plate_keys = list(range(nplates - 1, -1, -1))

        # Keys for the output plates: keep a plate axis only if the receiver
        # has that axis and it is not a unit (broadcast) axis there
        out_plate_keys = [key
                          for key in in_plate_keys
                          if key < len(plates_to) and plates_to[-key-1] != 1]

        # Keys for the dims
        dim_keys = list(range(nplates, nplates+ndim))

        # Total input and output keys
        in_keys = len(arrays) * [in_plate_keys + dim_keys]
        out_keys = out_plate_keys + dim_keys

        # Compute the sum-product with correction
        einsum_args = misc.zipper_merge(arrays, in_keys) + [out_keys]
        y = r * np.einsum(*einsum_args)

        # Reshape the result so that summed plate axes which the receiver
        # still has come back as unit axes.  The comprehension is empty when
        # there are no resulting plates.
        nplates_result = min(len(plates_to), len(arrays_plates))
        plates_result = [min(plates_to[ind], arrays_plates[ind])
                         for ind in range(-nplates_result, 0)]
        y = np.reshape(y, plates_result + list(dims))

        return y
Example #6
0
    def _compute_message(*arrays, plates_from=(), plates_to=(), ndim=0):
        """
        A general function for computing messages by sum-multiply

        The function computes the product of the input arrays and then sums to
        the requested plates.

        Parameters
        ----------
        *arrays
            Arrays to multiply together.  Their shapes are combined with
            `misc.broadcasted_shape`.
        plates_from : tuple
            Plates of the sender.
        plates_to : tuple
            Plates of the receiver.  Must pass
            `misc.is_shape_subset(plates_to, plates_from)`.
        ndim : int
            Number of trailing axes that are variable dimensions rather than
            plates.  These axes are never summed.

        Returns
        -------
        y : ndarray
            The sum-product scaled by the broadcasting correction from
            `Node.broadcasting_multiplier`.  It keeps at most `len(plates_to)`
            plate axes, followed by the `ndim` dimension axes.

        Raises
        ------
        ValueError
            If `plates_to` is not a shape subset of `plates_from`.
        """

        # Check that the plates broadcast properly
        if not misc.is_shape_subset(plates_to, plates_from):
            raise ValueError("plates_to must be broadcastable to plates_from")

        # Compute the explicit shape of the product
        shapes = [np.shape(array) for array in arrays]
        arrays_shape = misc.broadcasted_shape(*shapes)

        # Compute plates and dims that are present
        if ndim == 0:
            arrays_plates = arrays_shape
            dims = ()
        else:
            arrays_plates = arrays_shape[:-ndim]
            dims = arrays_shape[-ndim:]

        # Compute the correction term.  If some of the plates that should be
        # summed are actually broadcasted, one must multiply by the size of the
        # corresponding plate
        r = Node.broadcasting_multiplier(plates_from, arrays_plates, plates_to)

        # For simplicity, make the arrays equal ndim
        arrays = misc.make_equal_ndim(*arrays)

        # Keys for the input plates: (N-1, N-2, ..., 0).  The keys must be in
        # descending order so that key k labels the k-th plate axis counted
        # from the end.  The output filter below relies on this when it looks
        # up plates_to[-key - 1].
        nplates = len(arrays_plates)
        in_plate_keys = list(range(nplates - 1, -1, -1))

        # Keys for the output plates: keep a plate axis only if the receiver
        # has that axis and it is not a unit (broadcast) axis there
        out_plate_keys = [
            key for key in in_plate_keys
            if key < len(plates_to) and plates_to[-key - 1] != 1
        ]

        # Keys for the dims
        dim_keys = list(range(nplates, nplates + ndim))

        # Total input and output keys
        in_keys = len(arrays) * [in_plate_keys + dim_keys]
        out_keys = out_plate_keys + dim_keys

        # Compute the sum-product with correction
        einsum_args = misc.zipper_merge(arrays, in_keys) + [out_keys]
        y = r * np.einsum(*einsum_args)

        # Reshape the result so that summed plate axes which the receiver
        # still has come back as unit axes.  The comprehension is empty when
        # there are no resulting plates.
        nplates_result = min(len(plates_to), len(arrays_plates))
        plates_result = [
            min(plates_to[ind], arrays_plates[ind])
            for ind in range(-nplates_result, 0)
        ]
        y = np.reshape(y, plates_result + list(dims))

        return y
Example #7
0
    def _compute_moments(self, *u_parents):
        """
        Compute the moments of this node from the moments of its parents.

        Parents flagged in `self.is_constant` contribute only their first
        moment (element 0): it is used once for the output mean and twice, in
        place of a second moment, for the output second moment.  Other
        parents contribute element 0 to the mean and element 1 to the second
        moment.  Both outputs are computed with `np.einsum` using the keys in
        `self.in_keys` and `self.out_keys`.

        If `self.gaussian_gamma` is true, elements 2 of the non-constant
        parents are multiplied together and elements 3 are summed.

        Returns `[x0, x1]`, or `[x0, x1, x2, x3]` when `self.gaussian_gamma`
        is true.
        """

        # The total number of unique keys used (keys are 0,1,...,N_keys-1)
        num_keys = self.N_keys

        def plate_keys(offset, count):
            # Keys for `count` plate axes: offset+count-1, ..., offset+1, offset
            return list(range(offset + count - 1, offset - 1, -1))

        # Number of plate axes in the first moment of each parent
        mean_plate_counts = []
        for (keys, u) in zip(self.in_keys, u_parents):
            mean_plate_counts.append(np.ndim(u[0]) - len(keys))

        # Number of plate axes in the array that provides the second moment
        cov_plate_counts = []
        for (keys, u, constant) in zip(self.in_keys,
                                       u_parents,
                                       self.is_constant):
            if constant:
                # Delta moments: Use first moment "vector"
                cov_plate_counts.append(np.ndim(u[0]) - len(keys))
            else:
                # Gaussian moments: Use second moments "matrix"
                cov_plate_counts.append(np.ndim(u[1]) - 2 * len(keys))

        # Number of plate axes in the outputs
        mean_out_plates = max(mean_plate_counts)
        cov_out_plates = max(cov_plate_counts)

        #
        # First moment: plate keys start from num_keys
        #
        mean_out_keys = plate_keys(num_keys, mean_out_plates) + self.out_keys
        mean_in_keys = []
        for (count, keys) in zip(mean_plate_counts, self.in_keys):
            mean_in_keys.append(plate_keys(num_keys, count) + keys)
        first_moments = [u[0] for u in u_parents]

        mean_args = (
            misc.zipper_merge(first_moments, mean_in_keys)
            + [mean_out_keys]
        )
        x0 = np.einsum(*mean_args)

        #
        # Second moment: each key k appears as the pair (num_keys+k, k) and
        # plate keys start from 2*num_keys
        #
        cov_out_keys = (
            plate_keys(2 * num_keys, cov_out_plates)
            + [num_keys + key for key in self.out_keys]
            + self.out_keys
        )

        cov_in_keys = []
        for (count, keys, constant) in zip(cov_plate_counts,
                                           self.in_keys,
                                           self.is_constant):
            plates = plate_keys(2 * num_keys, count)
            shifted_keys = [num_keys + key for key in keys]
            if constant:
                # Delta moments: the first moment is used twice, once with
                # the shifted keys and once with the plain keys
                cov_in_keys.append(plates + shifted_keys)
                cov_in_keys.append(plates + keys)
            else:
                # Gaussian moments: the second moment carries both key sets
                cov_in_keys.append(plates + shifted_keys + keys)

        cov_operands = []
        for (u, constant) in zip(u_parents, self.is_constant):
            if constant:
                # Delta moments: Use the first moment twice
                cov_operands.extend([u[0], u[0]])
            else:
                # Gaussian moments: Use the second moment
                cov_operands.append(u[1])

        cov_args = (
            misc.zipper_merge(cov_operands, cov_in_keys)
            + [cov_out_keys]
        )
        x1 = np.einsum(*cov_args)

        if not self.gaussian_gamma:
            return [x0, x1]

        # Gaussian-gamma specific moments: constant parents contribute the
        # neutral elements (1 for the product, 0 for the sum)
        x2 = 1
        x3 = 0
        for (index, u) in enumerate(u_parents):
            if self.is_constant[index]:
                factor = 1
            else:
                factor = u[2]
            x2 = x2 * factor
            if self.is_constant[index]:
                term = 0
            else:
                term = u[3]
            x3 = x3 + term

        return [x0, x1, x2, x3]