def get_mpe_path(self, root):
    """Assemble TF operations computing the true branch counts for the MPE
    downward path through the SPN rooted in ``root``.

    Args:
        root (Node): The root node of the SPN graph.
    """
    def down_fun(node, parent_vals):
        # Fold the values arriving from all parents into a single tensor
        # and remember it as this node's true counts.
        summed = self._accumulate_parents(*parent_vals)
        self._true_counts[node] = summed
        if not node.is_op:
            return None
        # The sampling/randomization keywords only apply to sum-type nodes;
        # all other op nodes take no extra keyword arguments.
        if isinstance(node, BaseSum):
            extra_kwargs = dict(
                add_random=self._add_random,
                use_unweighted=self._use_unweighted,
                sample=self._sample,
                sample_prob=self._sample_prob)
        else:
            extra_kwargs = dict()
        input_values = [self._value.values[inp.node] if inp else None
                        for inp in node.inputs]
        with tf.name_scope(node.name):
            return node._compute_log_mpe_path(
                summed, *input_values, **extra_kwargs)

    # Make sure the upward value pass has been generated before
    # traversing downward.
    if not self._value.values:
        self._value.get_value(root)

    with tf.name_scope("TrueMPEPath"):
        # Tensor fed to the root node to seed the downward pass
        graph_input = self._graph_input(self._value.values[root])
        # Traverse the graph computing counts
        self._true_counts = {}
        compute_graph_up_down(root, down_fun=down_fun, graph_input=graph_input)
def get_gradients(self, root):
    """Assemble TF operations computing the gradients of the SPN rooted in
    ``root``.

    Args:
        root (Node): The root node of the SPN graph.
    """
    def down_fun(node, parent_vals):
        # Sum up all parent vals, ignoring disconnected (None) parents
        parent_vals = [pv for pv in parent_vals if pv is not None]
        if len(parent_vals) > 1:
            summed = tf.add_n(parent_vals, name=node.name + "_add")
        else:
            summed = parent_vals[0]
        self._true_gradients[node] = summed
        if node.is_op:
            # Only sum-type nodes take the dropconnect keyword argument
            if isinstance(node, BaseSum):
                kwargs = dict(
                    dropconnect_keep_prob=self._dropconnect_keep_prob)
            else:
                kwargs = dict()
            # NOTE(review): the original code branched on ``self._log`` but
            # called ``node._compute_log_gradient`` with identical arguments
            # in BOTH branches, so the branch was dead code and is collapsed
            # here without changing behavior. If a non-log gradient method
            # exists, the non-log path presumably intended to call it —
            # confirm before restoring the branch.
            with tf.name_scope(node.name):
                return node._compute_log_gradient(
                    summed,
                    *[self._value.values[i.node] if i else None
                      for i in node.inputs],
                    **kwargs)

    # Generate values if not yet generated
    if not self._value.values:
        self._value.get_value(root)

    with tf.name_scope("Gradient"):
        # Seed the downward pass with ones at the root
        graph_input = tf.ones_like(self._value.values[root])
        # Traverse the graph computing gradients
        self._true_gradients = {}
        compute_graph_up_down(root, down_fun=down_fun, graph_input=graph_input)
def get_mpe_path(self, root):
    """Assemble TF operations computing the branch counts for the MPE
    downward path through the SPN rooted in ``root``.

    Args:
        root (Node): The root node of the SPN graph.
    """
    def down_fun(node, parent_vals):
        # Combine the counts arriving from all parents into one tensor.
        if len(parent_vals) > 1:
            summed = tf.add_n(parent_vals, name=node.name + "_add")
        else:
            summed = parent_vals[0]
        self._counts[node] = summed
        if not node.is_op:
            return None
        # Pick the log-space or linear-space path computation up front,
        # then call it with identical arguments.
        compute_fn = (node._compute_log_mpe_path if self._log
                      else node._compute_mpe_path)
        input_values = [self._value.values[inp.node] if inp else None
                        for inp in node.inputs]
        with tf.name_scope(node.name):
            return compute_fn(
                summed, *input_values,
                add_random=self._add_random,
                use_unweighted=self._use_unweighted)

    # Make sure the upward value pass has been generated before
    # traversing downward.
    if not self._value.values:
        self._value.get_value(root)

    with tf.name_scope("MPEPath"):
        # Tensor of ones fed to the root node to seed the downward pass
        graph_input = tf.ones_like(self._value.values[root])
        # Traverse the graph computing counts
        self._counts = {}
        compute_graph_up_down(root, down_fun=down_fun, graph_input=graph_input)