# NOTE: these snippets assume the imports below; the graph-node classes
# (ConvProducts, ConvProductsDepthwise, LocalSums, ParallelSums, Sum, ...)
# are defined elsewhere in the enclosing library.
import random
from collections import deque

import numpy as np
import tensorflow as tf


def full_wicker(in_var, num_channels_prod, num_channels_sums, num_classes=10,
                edge_size=28, first_depthwise=False, supervised=True):
    # Number of dilated product/sum blocks needed to cover the full image scope.
    stack_size = int(np.ceil(np.log2(edge_size))) - 1

    if first_depthwise:
        prod0 = ConvProductsDepthwise(
            in_var, padding='full', kernel_size=2, strides=1,
            spatial_dim_sizes=[edge_size, edge_size])
    else:
        prod0 = ConvProducts(
            in_var, num_channels=num_channels_prod[0], padding='full',
            kernel_size=2, strides=1, spatial_dim_sizes=[edge_size, edge_size])
    h = LocalSums(prod0, num_channels=num_channels_sums[0])

    # Stack of depthwise product layers with exponentially growing dilation rates.
    for i in range(stack_size):
        dilation_rate = 2 ** (i + 1)
        h = ConvProductsDepthwise(
            h, padding='full', kernel_size=2, strides=1,
            dilation_rate=dilation_rate)
        h = LocalSums(h, num_channels=num_channels_sums[1 + i])

    # Final product layer joining the remaining scopes into a single full scope.
    full_scope_prod = ConvProductsDepthwise(
        h, padding='wicker_top', kernel_size=2, strides=1,
        dilation_rate=2 ** (stack_size + 1))

    if supervised:
        class_roots = ParallelSums(full_scope_prod, num_sums=num_classes)
        root = Sum(class_roots)
        return root, class_roots
    return Sum(full_scope_prod), None
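# A minimal usage sketch (assumptions: `in_var` is a leaf node over 28x28
# inputs created elsewhere; the channel counts are illustrative, not from the
# source). With edge_size=28, stack_size = ceil(log2(28)) - 1 = 4, so
# `num_channels_sums` needs 1 + stack_size = 5 entries:
root, class_roots = full_wicker(
    in_var,
    num_channels_prod=[32],
    num_channels_sums=[32, 64, 64, 64, 64],
    num_classes=10,
    edge_size=28,
    supervised=True)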
def wicker_convspn_two_non_overlapping(
        in_var, num_channels_prod, num_channels_sums, num_classes=10,
        edge_size=28, first_depthwise=False, supervised=True):
    # The two stride-2 product layers quarter the spatial dims, so fewer
    # dilated blocks are needed to cover the full scope than in `full_wicker`.
    stack_size = int(np.ceil(np.log2(edge_size))) - 2

    if first_depthwise:
        prod0 = ConvProductsDepthwise(
            in_var, padding='valid', kernel_size=2, strides=2,
            spatial_dim_sizes=[edge_size, edge_size])
    else:
        prod0 = ConvProducts(
            in_var, num_channels=num_channels_prod[0], padding='valid',
            kernel_size=2, strides=2, spatial_dim_sizes=[edge_size, edge_size])
    sum0 = LocalSums(prod0, num_channels=num_channels_sums[0])
    prod1 = ConvProductsDepthwise(sum0, padding='valid', kernel_size=2, strides=2)
    h = LocalSums(prod1, num_channels=num_channels_sums[1])

    # Stack of dilated depthwise product layers with 'full' padding.
    for i in range(stack_size):
        dilation_rate = 2 ** i
        h = ConvProductsDepthwise(
            h, padding='full', kernel_size=2, strides=1,
            dilation_rate=dilation_rate)
        h = LocalSums(h, num_channels=num_channels_sums[2 + i])

    # Final product layer joining the remaining scopes into a single full scope.
    full_scope_prod = ConvProductsDepthwise(
        h, padding='wicker_top', kernel_size=2, strides=1,
        dilation_rate=2 ** stack_size)

    if supervised:
        class_roots = ParallelSums(full_scope_prod, num_sums=num_classes)
        root = Sum(class_roots)
        return root, class_roots
    return Sum(full_scope_prod), None
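# Sketch under the same assumptions as above (illustrative channel counts).
# With edge_size=28, stack_size = ceil(log2(28)) - 2 = 3, so `num_channels_sums`
# needs 2 + stack_size = 5 entries: two for the strided blocks, three for the
# dilated blocks.
root, class_roots = wicker_convspn_two_non_overlapping(
    in_var,
    num_channels_prod=[16],
    num_channels_sums=[16, 32, 32, 64, 64],
    num_classes=10,
    edge_size=28)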
def _log_likelihood(self):
    """Computes log(p(X)) by creating a copy of the root node without latent
    indicators.

    Returns:
        A Tensor of shape [batch, 1] corresponding to the log likelihood of
        the data.
    """
    if isinstance(self._root, BaseSum):
        marginalizing_root = self._marginalizing_root or Sum(
            *self._root.values, weights=self._root.weights)
    else:
        marginalizing_root = self._marginalizing_root or BlockSum(
            self._root.values[0], weights=self._root.weights,
            num_sums_per_block=1)
    return self._log_value.get_value(marginalizing_root)
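# Why an indicator-free copy of the root gives p(X): with the class indicator Y
# clamped, the root computes p(X, Y), while a Sum over the same children with
# the same weights, but without latent indicators, sums Y out:
#   log p(X) = logsumexp_y( log w_y + log p_y(X) )
# A conceptually equivalent computation (a sketch only; `log_class_weights` and
# `log_class_values` are hypothetical tensors, not library attributes):
log_p_x = tf.reduce_logsumexp(log_class_weights + log_class_values,
                              axis=-1, keepdims=True)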
def generate(self, *inputs, rnd=None, root_name=None):
    """Generate the SPN.

    Args:
        inputs (input_like): Inputs to the generated SPN.
        rnd (Random): Optional. A custom instance of a random number generator
            ``random.Random`` that will be used instead of the default global
            instance. This permits using a generator with a custom state
            independent of the global one.
        root_name (str): Name of the root node of the generated SPN.

    Returns:
        Sum: Root node of the generated SPN.
    """
    self.__debug1(
        "Generating dense SPN (num_decomps=%s, num_subsets=%s,"
        " num_mixtures=%s, input_dist=%s, num_input_mixtures=%s)",
        self.num_decomps, self.num_subsets,
        self.num_mixtures, self.input_dist, self.num_input_mixtures)
    inputs = [Input.as_input(i) for i in inputs]
    input_set = self.__generate_set(inputs)
    self.__debug1("Found %s distinct input scopes", len(input_set))

    # Create root
    root = Sum(name=root_name)

    # Subsets left to process
    subsets = deque()
    subsets.append(DenseSPNGenerator.SubsetInfo(level=1, subset=input_set,
                                                parents=[root]))

    # Process subsets layer by layer
    self.__decomp_id = 1  # Id number of a decomposition, for info only
    while subsets:
        # Process whole layer (all subsets at the same level)
        level = subsets[0].level
        self.__debug1("Processing level %s", level)
        while subsets and subsets[0].level == level:
            subset = subsets.popleft()
            new_subsets = self.__add_decompositions(subset, rnd)
            for s in new_subsets:
                subsets.append(s)

    # If NodeType is LAYER, convert the generated graph with LayerNodes
    return (self.convert_to_layer_nodes(root)
            if self.node_type == DenseSPNGenerator.NodeType.LAYER else root)
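# A minimal usage sketch (assumption: `leaf` is an IndicatorLeaf over the
# sample variables, created elsewhere). Passing an explicit `rnd` makes the
# generated structure reproducible:
gen = DenseSPNGenerator(num_decomps=1, num_subsets=2, num_mixtures=4,
                        input_dist=DenseSPNGenerator.InputDist.MIXTURE,
                        num_input_mixtures=None, balanced=True)
root = gen.generate(leaf, rnd=random.Random(1234), root_name="Root")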
class DiscreteDenseModel(Model):
    """Basic dense SPN model operating on discrete data.

    If ``num_classes`` is greater than 1, a multi-class model is created by
    generating multiple parallel dense models (one for each class) and
    combining them with a sum node with an explicit latent class variable.

    Args:
        num_classes (int): Number of classes assumed by the model.
        num_decomps (int): Number of decompositions at each level of the dense SPN.
        num_subsets (int): Number of variable sub-sets for each decomposition.
        num_mixtures (int): Number of mixtures (sums) for each variable subset.
        input_dist (InputDist): Determines how the IndicatorLeaf of the discrete
            variables for data samples are connected to the model.
        num_input_mixtures (int): Number of mixtures used for representing each
            discrete data variable (mixing the data variable IndicatorLeaf)
            when ``input_dist`` is set to ``MIXTURE``. If set to ``None``,
            ``num_mixtures`` is used.
        weight_initializer: Initializer used for the weights.
    """

    __logger = get_logger()
    __info = __logger.info
    __debug1 = __logger.debug1
    __is_debug1 = __logger.is_debug1
    __debug2 = __logger.debug2
    __is_debug2 = __logger.is_debug2

    def __init__(self, num_classes, num_decomps, num_subsets, num_mixtures,
                 input_dist=DenseSPNGenerator.InputDist.MIXTURE,
                 num_input_mixtures=None,
                 weight_initializer=tf.initializers.random_uniform(0.0, 1.0)):
        super().__init__()
        if not isinstance(num_classes, int):
            raise ValueError("num_classes must be an integer")
        self._num_classes = num_classes
        self._num_decomps = num_decomps
        self._num_subsets = num_subsets
        self._num_mixtures = num_mixtures
        self._input_dist = input_dist
        self._num_input_mixtures = num_input_mixtures
        self._weight_initializer = weight_initializer
        self._class_latent_indicators = None
        self._sample_latent_indicators = None
        self._class_input = None
        self._sample_inputs = None

    def __repr__(self):
        return (type(self).__qualname__ + "("
                + ("num_classes=" + str(self._num_classes)) + ", "
                + ("num_decomps=" + str(self._num_decomps)) + ", "
                + ("num_subsets=" + str(self._num_subsets)) + ", "
                + ("num_mixtures=" + str(self._num_mixtures)) + ", "
                + ("input_dist=" + str(self._input_dist)) + ", "
                + ("num_input_mixtures=" + str(self._num_input_mixtures)) + ", "
                + ("weight_initializer=" + str(self._weight_initializer)) + ")")

    @utils.docinherit(Model)
    def serialize(self, save_param_vals=True, sess=None):
        # Serialize the graph first
        data = serialize_graph(self._root, save_param_vals=save_param_vals,
                               sess=sess)
        # Add model specific information
        # Inputs
        if self._sample_latent_indicators is not None:
            data['sample_latent_indicators'] = self._sample_latent_indicators.name
        data['sample_inputs'] = [(i.node.name, i.indices)
                                 for i in self._sample_inputs]
        if self._class_latent_indicators is not None:
            data['class_latent_indicators'] = self._class_latent_indicators.name
        if self._class_input:
            data['class_input'] = (self._class_input.node.name,
                                   self._class_input.indices)
        # Model params
        data['num_classes'] = self._num_classes
        data['num_decomps'] = self._num_decomps
        data['num_subsets'] = self._num_subsets
        data['num_mixtures'] = self._num_mixtures
        data['input_dist'] = self._input_dist
        data['num_input_mixtures'] = self._num_input_mixtures
        data['weight_init_value'] = self._weight_initializer
        return data

    @utils.docinherit(Model)
    def deserialize(self, data, load_param_vals=True, sess=None):
        # Deserialize the graph first
        nodes_by_name = {}
        self._root = deserialize_graph(data, load_param_vals=load_param_vals,
                                       sess=sess, nodes_by_name=nodes_by_name)
        # Model specific information
        # Inputs
        sample_latent_indicators = data.get('sample_latent_indicators', None)
        if sample_latent_indicators:
            self._sample_latent_indicators = nodes_by_name[sample_latent_indicators]
        else:
            self._sample_latent_indicators = None
        self._sample_inputs = tuple(Input(nodes_by_name[nn], i)
                                    for nn, i in data['sample_inputs'])
        class_latent_indicators = data.get('class_latent_indicators', None)
        if class_latent_indicators:
            self._class_latent_indicators = nodes_by_name[class_latent_indicators]
        else:
            self._class_latent_indicators = None
        class_input = data.get('class_input', None)
        if class_input:
            self._class_input = Input(nodes_by_name[class_input[0]], class_input[1])
        else:
            self._class_input = None
        # Model params
        self._num_classes = data['num_classes']
        self._num_decomps = data['num_decomps']
        self._num_subsets = data['num_subsets']
        self._num_mixtures = data['num_mixtures']
        self._input_dist = data['input_dist']
        self._num_input_mixtures = data['num_input_mixtures']
        self._weight_initializer = data['weight_init_value']

    @property
    def sample_latent_indicators(self):
        """IndicatorLeaf: IndicatorLeaf with the input data samples."""
        return self._sample_latent_indicators

    @property
    def class_latent_indicators(self):
        """IndicatorLeaf: Class indicator variables."""
        return self._class_latent_indicators

    @property
    def sample_inputs(self):
        """list of Input: Inputs to the model providing data samples."""
        return self._sample_inputs

    @property
    def class_input(self):
        """Input: Input providing class indicators."""
        return self._class_input

    def build(self, *sample_inputs, class_input=None, num_vars=None,
              num_vals=None, seed=None):
        """Build the SPN graph of the model.

        The model can be built on top of any ``sample_inputs``. Otherwise, if
        no sample inputs are provided, the model will internally create a
        single IndicatorLeaf node to represent the input data samples. In that
        case, ``num_vars`` and ``num_vals`` must be specified.

        Similarly, if ``class_input`` is provided, it is used as a source of
        class indicators for the root sum node combining the sub-SPNs modeling
        particular classes. Otherwise, an internal IndicatorLeaf node is
        created for this purpose.

        Args:
            *sample_inputs (input_like): Optional. Inputs to the model
                providing data samples.
            class_input (input_like): Optional. Input providing class
                indicators.
            num_vars (int): Optional. Number of variables in each sample. Must
                only be provided if ``sample_inputs`` are not given.
            num_vals (int): Optional. Number of values of each variable. Must
                only be provided if ``sample_inputs`` are not given.
            seed (int): Optional. Seed used for the dense SPN generator.

        Returns:
            Sum: Root node of the generated model.
        """
        if not sample_inputs:
            if num_vars is None:
                raise ValueError("num_vars must be given when sample_inputs are not")
            if num_vals is None:
                raise ValueError("num_vals must be given when sample_inputs are not")
            if not isinstance(num_vars, int) or num_vars < 1:
                raise ValueError("num_vars must be an integer > 0")
            if not isinstance(num_vals, int) or num_vals < 1:
                raise ValueError("num_vals must be an integer > 0")
        if self._num_classes > 1:
            self.__info("Building a discrete dense model with %d classes"
                        % self._num_classes)
        else:
            self.__info("Building a 1-class discrete dense model")

        # Create IndicatorLeaf if inputs not given
        if not sample_inputs:
            self._sample_latent_indicators = IndicatorLeaf(
                num_vars=num_vars, num_vals=num_vals, name="SampleIndicatorLeaf")
            self._sample_inputs = [Input(self._sample_latent_indicators)]
        else:
            self._sample_inputs = tuple(Input.as_input(i) for i in sample_inputs)
        if self._num_classes > 1:
            if class_input is None:
                self._class_latent_indicators = IndicatorLeaf(
                    num_vars=1, num_vals=self._num_classes,
                    name="ClassIndicatorLeaf")
                self._class_input = Input(self._class_latent_indicators)
            else:
                self._class_input = Input.as_input(class_input)

        # Generate structure
        dense_gen = DenseSPNGenerator(num_decomps=self._num_decomps,
                                      num_subsets=self._num_subsets,
                                      num_mixtures=self._num_mixtures,
                                      input_dist=self._input_dist,
                                      num_input_mixtures=self._num_input_mixtures,
                                      balanced=True)
        rnd = random.Random(seed)
        if self._num_classes == 1:
            # One-class
            self._root = dense_gen.generate(*self._sample_inputs, rnd=rnd,
                                            root_name='Root')
        else:
            # Multi-class: create sub-SPNs
            sub_spns = []
            for c in range(self._num_classes):
                rnd_copy = random.Random()
                rnd_copy.setstate(rnd.getstate())
                with tf.name_scope("Class%d" % c):
                    sub_root = dense_gen.generate(*self._sample_inputs,
                                                  rnd=rnd_copy)
                if self.__is_debug1():
                    self.__debug1("sub-SPN %d has %d nodes"
                                  % (c, sub_root.get_num_nodes()))
                sub_spns.append(sub_root)
            # Create root
            self._root = Sum(*sub_spns, latent_indicators=self._class_input,
                             name="Root")
        if self.__is_debug1():
            self.__debug1("SPN graph has %d nodes" % self._root.get_num_nodes())

        # Generate weight nodes
        self.__debug1("Generating weight nodes")
        generate_weights(self._root, initializer=self._weight_initializer)
        if self.__is_debug1():
            self.__debug1("SPN graph has %d nodes and %d TF ops" % (
                self._root.get_num_nodes(), self._root.get_tf_graph_size()))
        return self._root
def __add_decompositions(self, subset_info: SubsetInfo, rnd: random.Random):
    """Add nodes for a single subset, i.e. an instance of ``num_decomps``
    decompositions of ``subset`` into ``num_subsets`` sub-subsets with
    ``num_mixtures`` mixtures per sub-subset.

    Args:
        subset_info (SubsetInfo): Info about the subset being decomposed.
        rnd (Random): A custom instance of a random number generator or
            ``None`` if the default global instance should be used.

    Returns:
        list of SubsetInfo: Info about each newly generated subset, which
        requires further decomposition.
    """

    def subsubset_to_inputs_list(subsubset):
        """Convert sub-subsets into a list of tuples, where each tuple
        contains an input and a list of indices."""
        subsubset_list = list(next(iter(subsubset)))
        # Create a list of unique inputs from the sub-subsets list
        unique_inputs = list(set(s_subset[0] for s_subset in subsubset_list))
        # For each unique input, collect all associated indices into a single
        # list, then create a list of tuples, where each tuple contains a
        # unique input and its list of indices
        inputs_list = []
        for unique_inp in unique_inputs:
            indices_list = []
            for s_subset in subsubset_list:
                if s_subset[0] == unique_inp:
                    indices_list.append(s_subset[1])
            inputs_list.append(tuple((unique_inp, indices_list)))
        return inputs_list

    # Get subset partitions
    self.__debug3("Decomposing subset:\n%s", subset_info.subset)
    num_elems = len(subset_info.subset)
    num_subsubsets = min(num_elems, self.num_subsets)  # Requested num subsets
    partitions = utils.random_partitions(subset_info.subset, num_subsubsets,
                                         self.num_decomps,
                                         balanced=self.balanced,
                                         rnd=rnd,
                                         stirling=self.__stirling)
    self.__debug2("Randomized %s decompositions of a subset"
                  " of %s elements into %s sets",
                  len(partitions), num_elems, num_subsubsets)

    # Generate nodes for each decomposition/partition
    subsubset_infos = []
    for part in partitions:
        self.__debug2("Decomposition %s: into %s subsubsets of cardinality %s",
                      self.__decomp_id, len(part), [len(s) for s in part])
        self.__debug3("Decomposition %s subsubsets:\n%s",
                      self.__decomp_id, part)
        # Handle each subsubset
        sums_id = 1
        prod_inputs = []
        for subsubset in part:
            if self.node_type == DenseSPNGenerator.NodeType.SINGLE:  # Use single nodes
                if len(subsubset) > 1:  # Decomposable further
                    # Add mixtures
                    with tf.name_scope("Sums%s.%s" % (self.__decomp_id, sums_id)):
                        sums = [Sum(name="Sum%s" % (i + 1))
                                for i in range(self.num_mixtures)]
                    sums_id += 1
                    # Register the mixtures as inputs of products
                    prod_inputs.append([(s, 0) for s in sums])
                    # Generate subsubset info
                    subsubset_infos.append(DenseSPNGenerator.SubsetInfo(
                        level=subset_info.level + 1, subset=subsubset,
                        parents=sums))
                else:  # Non-decomposable
                    if self.input_dist == DenseSPNGenerator.InputDist.RAW:
                        # Register the content of subset as inputs to products
                        prod_inputs.append(next(iter(subsubset)))
                    elif self.input_dist == DenseSPNGenerator.InputDist.MIXTURE:
                        # Add mixtures
                        with tf.name_scope("Sums%s.%s" % (self.__decomp_id, sums_id)):
                            sums = [Sum(name="Sum%s" % (i + 1))
                                    for i in range(self.num_input_mixtures)]
                        sums_id += 1
                        # Register the mixtures as inputs of products
                        prod_inputs.append([(s, 0) for s in sums])
                        # Create an inputs list
                        inputs_list = subsubset_to_inputs_list(subsubset)
                        # Connect inputs to mixtures
                        for s in sums:
                            s.add_values(*inputs_list)
            else:  # Use multi-nodes
                if len(subsubset) > 1:  # Decomposable further
                    # Add mixtures
                    with tf.name_scope("Sums%s.%s" % (self.__decomp_id, sums_id)):
                        sums = ParallelSums(
                            num_sums=self.num_mixtures,
                            name="ParallelSums%s.%s" % (self.__decomp_id, sums_id))
                    sums_id += 1
                    # Register the mixtures as inputs of PermuteProducts
                    prod_inputs.append(sums)
                    # Generate subsubset info
                    subsubset_infos.append(DenseSPNGenerator.SubsetInfo(
                        level=subset_info.level + 1, subset=subsubset,
                        parents=[sums]))
                else:  # Non-decomposable
                    if self.input_dist == DenseSPNGenerator.InputDist.RAW:
                        # Create an inputs list
                        inputs_list = subsubset_to_inputs_list(subsubset)
                        if len(inputs_list) > 1:
                            inputs_list = [Concat(*inputs_list)]
                        # Register the content of subset as inputs to PermuteProducts
                        prod_inputs.extend(inputs_list)
                    elif self.input_dist == DenseSPNGenerator.InputDist.MIXTURE:
                        # Create an inputs list
                        inputs_list = subsubset_to_inputs_list(subsubset)
                        # Add mixtures
                        with tf.name_scope("Sums%s.%s" % (self.__decomp_id, sums_id)):
                            sums = ParallelSums(
                                *inputs_list, num_sums=self.num_input_mixtures,
                                name="ParallelSums%s.%s" % (self.__decomp_id, sums_id))
                        sums_id += 1
                        # Register the mixtures as inputs of PermuteProducts
                        prod_inputs.append(sums)

        # Add product nodes
        if self.node_type == DenseSPNGenerator.NodeType.SINGLE:
            products = self.__add_products(prod_inputs)
        else:
            products = ([PermuteProducts(*prod_inputs,
                                         name="PermuteProducts%s" % self.__decomp_id)]
                        if len(prod_inputs) > 1 else prod_inputs)

        # Connect products to each parent Sum
        for p in subset_info.parents:
            p.add_values(*products)

        # Increment decomposition id
        self.__decomp_id += 1
    return subsubset_infos
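# Illustration (hypothetical leaves) of what `subsubset_to_inputs_list`
# produces: a sub-subset holding (input, index) pairs over two distinct leaves
# collapses to one (input, indices) tuple per unique input:
#   {(leaf_a, 0), (leaf_a, 3), (leaf_b, 1)}  ->  [(leaf_a, [0, 3]), (leaf_b, [1])]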
class Poon11NaiveMixtureModel(Model):
    """A simple naive Bayes mixture from the Poon & Domingos 2011 paper.

    The model is only used for testing. The weights of the model are
    initialized to specific values, for which various quantities are
    pre-computed.
    """

    def __init__(self):
        super().__init__()
        self._latent_indicators = None

    @utils.docinherit(Model)
    def serialize(self, save_param_vals=True, sess=None):
        raise NotImplementedError("Serialization not implemented")

    @utils.docinherit(Model)
    def deserialize(self, data, load_param_vals=True, sess=None):
        raise NotImplementedError("Deserialization not implemented")

    @property
    def latent_indicators(self):
        """IndicatorLeaf: The IndicatorLeaf with the input variables of the model."""
        return self._latent_indicators

    @property
    def true_mpe_state(self):
        """The true MPE state for the SPN."""
        return np.array([1, 0])

    @property
    def true_values(self):
        """The true values of the SPN for the :meth:`feed`."""
        return np.array([[1.0],
                         [0.75],
                         [0.25],
                         [0.31],
                         [0.228],
                         [0.082],
                         [0.69],
                         [0.522],
                         [0.168]], dtype=conf.dtype.as_numpy_dtype)

    @property
    def true_mpe_values(self):
        """The true MPE values of the SPN for the :meth:`feed`."""
        return np.array([[0.216],
                         [0.216],
                         [0.09],
                         [0.14],
                         [0.14],
                         [0.06],
                         [0.216],
                         [0.216],
                         [0.09]], dtype=conf.dtype.as_numpy_dtype)

    @property
    def feed(self):
        """Feed containing all possible values of the input variables."""
        values = np.arange(-1, 2)
        points = np.array(np.meshgrid(*[values for i in range(2)])).T
        return points.reshape(-1, points.shape[-1])

    @utils.docinherit(Model)
    def build(self):
        # Inputs
        self._latent_indicators = IndicatorLeaf(num_vars=2, num_vals=2,
                                                name="IndicatorLeaf")
        # Input mixtures
        s11 = Sum((self._latent_indicators, [0, 1]), name="Sum1.1")
        s11.generate_weights(tf.initializers.constant([0.4, 0.6]))
        s12 = Sum((self._latent_indicators, [0, 1]), name="Sum1.2")
        s12.generate_weights(tf.initializers.constant([0.1, 0.9]))
        s21 = Sum((self._latent_indicators, [2, 3]), name="Sum2.1")
        s21.generate_weights(tf.initializers.constant([0.7, 0.3]))
        s22 = Sum((self._latent_indicators, [2, 3]), name="Sum2.2")
        s22.generate_weights(tf.initializers.constant([0.8, 0.2]))
        # Components
        p1 = Product(s11, s21, name="Comp1")
        p2 = Product(s11, s22, name="Comp2")
        p3 = Product(s12, s22, name="Comp3")
        # Mixing components
        self._root = Sum(p1, p2, p3, name="Mixture")
        self._root.generate_weights(tf.initializers.constant([0.5, 0.2, 0.3]))
        return self._root
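# Worked check of one `true_values` entry: the feed enumerates {-1, 0, 1} for
# both variables (-1 denotes missing evidence, marginalizing that variable
# out), so row 4 of the feed is (x1=0, x2=0). With the weights fixed in
# `build`:
#   p(x1=0, x2=0) = 0.5*(0.4*0.7) + 0.2*(0.4*0.8) + 0.3*(0.1*0.8)
#                 = 0.14 + 0.064 + 0.024 = 0.228
# which matches true_values[4]. The MPE value replaces the root sum with a
# weighted max: max(0.14, 0.064, 0.024) = 0.14, matching true_mpe_values[4].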