Example #1
0
    def get_value(self, root):
        """Assemble a TF operation computing the values of nodes of the SPN
        rooted in ``root``.

        Returns the operation computing the value for the ``root``. Operations
        computing values for other nodes can be obtained using :obj:`values`.

        Args:
            root (Node): The root node of the SPN graph.

        Returns:
            Tensor: A tensor of shape ``[None, num_outputs]``, where the first
            dimension corresponds to the batch size.
        """
        def val_fun(node, *args):
            with tf.name_scope(node.name):
                # Marginal inference is used when explicitly requested, or
                # when no global inference type is set and the node itself
                # asks for marginal inference.
                use_marginal = (
                    self._inference_type == InferenceType.MARGINAL
                    or (self._inference_type is None
                        and node.inference_type == InferenceType.MARGINAL))
                if use_marginal:
                    return node._compute_value(*args)
                return node._compute_mpe_value(*args)

        self._values = {}
        with tf.name_scope("Value"):
            return compute_graph_up(root, val_fun=val_fun,
                                    all_values=self._values)
Example #2
0
    def get_value(self, root):
        """Assemble TF operations computing the log values of nodes of the SPN
        rooted in ``root``.

        Returns the operation computing the log value for the ``root``.
        Operations computing log values for other nodes can be obtained using
        :obj:`values`.

        Args:
            root: Root node of the SPN.

        Returns:
            Tensor: A tensor of shape ``[None, num_outputs]``, where the first
            dimension corresponds to the batch size.
        """
        def val_fun(node, *args):
            # Dropconnect only applies to sum nodes, and only when a keep
            # probability was configured.
            kwargs = {}
            if self._dropconnect_keep_prob and isinstance(node, BaseSum):
                kwargs["dropconnect_keep_prob"] = self._dropconnect_keep_prob
            with tf.name_scope(node.name):
                marginal = (
                    self._inference_type == InferenceType.MARGINAL
                    or (self._inference_type is None
                        and node.inference_type == InferenceType.MARGINAL))
                compute = (node._compute_log_value if marginal
                           else node._compute_log_mpe_value)
                return compute(*args, **kwargs)

        self._values = {}
        with tf.name_scope(self._name):
            return compute_graph_up(root, val_fun=val_fun,
                                    all_values=self._values)
Example #3
0
    def get_scope(self):
        """Get the scope of each output value of this node.

        Returns:
            list of Scope: A list of length ``out_size`` containing scopes of
                           each output of this node.
        """
        def scope_fun(node, *args):
            return node._compute_scope(*args)

        return compute_graph_up(self, scope_fun)
Example #4
0
    def get_out_size(self):
        """Get the size of the output of this node.  The size might depend on
        the inputs of this node and might change if new inputs are added.

        Returns:
            int: The size of the output.
        """
        def size_fun(node, *args):
            return node._compute_out_size(*args)

        def const_fun(node):
            # Nodes with a constant output size can stop the traversal early.
            return node._const_out_size

        return compute_graph_up(self, size_fun, const_fun)
Example #5
0
    def is_valid(self):
        """Check if the SPN rooted in this node is complete and decomposable.
        If a node has multiple outputs, it is considered valid if all outputs
        of that node come from a valid SPN.

        Returns:
            bool: ``True`` if this SPN is complete and decomposable.
        """
        # _compute_valid returns None for an invalid SPN, so a non-None
        # result for the root means the whole graph validated.
        result = compute_graph_up(
            self, lambda node, *args: node._compute_valid(*args))
        return result is not None
Example #6
0
    def get_depth(self):
        """Get depth of the SPN.

        Returns:
            int: The depth of the SPN.
        """
        def _increment(_, *args):
            # A node's depth is one more than its deepest child's depth;
            # a node with no non-None child depths (a leaf) has depth 0,
            # which max(..., default=-1) + 1 yields directly.
            not_none = [a for a in args if a is not None]
            return max(not_none, default=-1) + 1

        return compute_graph_up(self, val_fun=_increment)
Example #7
0
    def generate(self, root):
        """Generate the weight nodes.

        Args:
            root: The root node of the SPN graph.
        """
        def attach_weights(node, *input_out_sizes):
            # Only sum nodes carry weights.
            if isinstance(node, Sum):
                sizes = node._gather_input_sizes(*input_out_sizes)
                self._weights[node] = node.generate_weights(
                    init_value=self.init_value,
                    trainable=self.trainable,
                    input_sizes=sizes)
            # Propagate this node's output size so parents can compute theirs.
            return node._compute_out_size(*input_out_sizes)

        with tf.name_scope("Weights"):
            self._weights = {}
            # Traverse the graph, generating weights along the way
            return compute_graph_up(root, val_fun=attach_weights)
Example #8
0
    def get_input_sizes(self, *input_tensors):
        """Get the sizes of inputs of this node (as selected by indices).
        If the input is disconnected, ``None`` is returned for that input.

        Args:
            *input_tensors (Tensor): Optional tensors with values produced by
                the nodes connected to the inputs. If not given, the input sizes
                will be computed by traversing the graph. If given, the input
                sizes will be computed based on the sizes of ``input_tensors``.
                If ``None`` is given for an input, ``None`` is returned for that
                input.

        Returns:
            list of int: For each input, the size of the input.
        """
        # Fast path: derive sizes directly from the provided tensors.
        if input_tensors:
            if len(input_tensors) != len(self.inputs):
                raise ValueError("Number of 'input_tensors' must be the same"
                                 " as the number of inputs.")
            sizes = []
            for inpt, tensor in zip(self.inputs, input_tensors):
                if not inpt or tensor is None:
                    sizes.append(None)
                else:
                    sizes.append(inpt.get_size(tensor))
            return tuple(sizes)

        # Slow path: traverse the graph to compute the sizes.
        def val_fun(node, *args):
            return (self._gather_input_sizes(*args) if node is self
                    else node._compute_out_size(*args))

        def const_fun(node):
            # Make sure to go through the children of this node
            return False if node is self else node._const_out_size

        return compute_graph_up(self, val_fun=val_fun, const_fun=const_fun)
Example #9
0
def display_spn_graph(root, skip_params=False):
    """Visualize an SPN graph in IPython/Jupyter.

    Args:
        root (Node): Root of the SPN to visualize.
        skip_params (bool): If ``True``, parameter nodes will not be shown.
    """
    # Graph description for HTML generation
    links = []
    nodes = []
    node_types = []
    leaf_counter = [1]  # Mutable cell so the nested closure can increment it

    def add_node(node, *input_out_sizes):
        """Add a node to the graph description. This function also computes the
        size of the output while traversing the graph. """
        # Get a unique node type number, for any node including leafs
        try:
            node_type = node_types.index(node.__class__)
        except ValueError:
            # Class not registered yet; assign it the next type number.
            # (bare except replaced: list.index raises ValueError, and a
            # bare except would also swallow KeyboardInterrupt/SystemExit)
            node_type = len(node_types)
            node_types.append(node.__class__)
        # Param and var nodes are added when processing an op node
        if node.is_op:
            # Create this node
            nodes.append({
                "id": node.name,
                "name": node.name,
                "type": node_type,
                "tooltip": str(node)
            })
            # Create links to inputs (and param/var nodes)
            for inpt, size in zip(node.inputs,
                                  node._gather_input_sizes(*input_out_sizes)):
                if inpt:
                    if inpt.is_op:
                        # WARNING: Currently if a node has two inputs from the
                        # same node, they will be added correctly, but displayed
                        # on top of each other
                        links.append({
                            "source": inpt.node.name,
                            "target": node.name,
                            "value": size
                        })
                    elif not skip_params or not inpt.is_param:
                        # Unique id for a leaf node
                        leaf_id = inpt.node.name + "_" + str(leaf_counter[0])
                        # Add indices in the name of the node
                        leaf_name = (inpt.node.name if inpt.indices is None
                                     else inpt.node._name + str(inpt.indices))
                        leaf_type = node_types.index(
                            inpt.node.__class__)  # Must exist
                        leaf_counter[0] += 1
                        nodes.append({
                            "id": leaf_id,
                            "name": leaf_name,
                            "type": leaf_type,
                            "tooltip": str(inpt.node)
                        })
                        links.append({
                            "source": leaf_id,
                            "target": node.name,
                            "value": size
                        })

        # Return computed outputs size
        return node._compute_out_size(*input_out_sizes)

    # Compute graph & build HTML
    compute_graph_up(root, val_fun=add_node)
    html = _html_graph(nodes, links)
    import IPython.display  # Import only if needed
    IPython.display.display(IPython.display.HTML(html))