Beispiel #1
0
    def compile(cls, source_net, compiled_net):
        """Add observed nodes to the computation graph.

        Nodes of `source_net` are visited in topological order.

        * A node whose state has `_observable` set gets a renamed "observed"
          copy in `compiled_net` (via `make_observed_copy`).
        * A node whose state has `_uses_observed` set gets an observed copy
          whose operation is `args_to_tuple`. That copy is connected to the
          original node as its ``param='observed'`` input.

        For these nodes, if they are not `_stochastic`, the incoming edges of
        the original node are copied onto the observed copy. Parents that are
        observable are replaced by their observed copies.

        Parameters
        ----------
        source_net : nx.DiGraph
            Source graph. Each node's state dict is read from its
            ``'attr_dict'`` node attribute.
        compiled_net : nx.DiGraph
            Graph being compiled; modified in place.

        Returns
        -------
        compiled_net : nx.DiGraph

        Raises
        ------
        ValueError
            If the observed copy of a `_uses_observed` node has an ancestor in
            `compiled_net` whose `source_net` state is flagged `_stochastic`.
            `make_observed_copy` may also raise `ValueError` when an observed
            copy already exists.

        """
        logger.debug("{} compiling...".format(cls.__name__))

        # Set: only used for membership tests when rerouting parent edges.
        observable = set()
        uses_observed = []

        # Topological order guarantees every parent is visited before its
        # children, so `observable` is complete for a node's parents by the
        # time that node's edges are copied below.
        for node in nx.topological_sort(source_net):
            state = source_net.nodes[node]['attr_dict']
            if state.get('_observable'):
                observable.add(node)
                cls.make_observed_copy(node, compiled_net)
            elif state.get('_uses_observed'):
                uses_observed.append(node)
                obs_node = cls.make_observed_copy(node, compiled_net,
                                                  args_to_tuple)
                # Make edge to the using node
                compiled_net.add_edge(obs_node, node, param='observed')
            else:
                continue

            # Copy the edges (only for non-stochastic nodes)
            if not state.get('_stochastic'):
                obs_node = observed_name(node)
                for parent in source_net.predecessors(node):
                    if parent in observable:
                        link_parent = observed_name(parent)
                    else:
                        link_parent = parent

                    compiled_net.add_edge(link_parent, obs_node,
                                          **source_net[parent][node].copy())

        # Check that there are no stochastic nodes in the ancestors
        for node in uses_observed:
            # Use the observed version to query observed ancestors in the compiled_net
            obs_node = observed_name(node)
            for ancestor_node in nx.ancestors(compiled_net, obs_node):
                # Observed copies exist only in compiled_net; they have no
                # source state to inspect.
                if ancestor_node not in source_net:
                    continue
                # The state flags live under 'attr_dict' (as read in the loop
                # above), not among the top-level node attributes.
                ancestor_state = source_net.nodes[ancestor_node].get('attr_dict', {})
                if ancestor_state.get('_stochastic'):
                    raise ValueError(
                        "Observed nodes must be deterministic. Observed "
                        "data depends on a non-deterministic node {}.".format(
                            ancestor_node))

        return compiled_net
Beispiel #2
0
    def load(cls, context, compiled_net, batch_index):
        """Add the observed data to the `compiled_net`.

        Each ``name -> obs`` entry of ``compiled_net.graph['observed']`` is
        handled as follows:

        * If the node ``observed_name(name)`` exists, its attributes are
          replaced so that ``output`` (holding `obs`) is its only attribute.
        * If the node does not exist, the entry is skipped.

        The ``'observed'`` graph attribute is deleted afterwards.

        Parameters
        ----------
        context : ComputationContext
            Not used by this loader.
        compiled_net : nx.DiGraph
            Modified in place.
        batch_index : int
            Not used by this loader.

        Returns
        -------
        net : nx.DiGraph
            Loaded net, which is the `compiled_net` that has been loaded with data that
            can depend on the batch_index.

        Raises
        ------
        KeyError
            If ``compiled_net.graph`` has no ``'observed'`` entry.

        """
        observed = compiled_net.graph['observed']

        for name, obs in observed.items():
            obs_name = observed_name(name)
            if not compiled_net.has_node(obs_name):
                continue
            # networkx 2.x does not allow item assignment on `G.nodes`, and
            # `G.node` was removed in 2.4. Mutate the node's attribute dict in
            # place instead. Clearing it first keeps the original semantics:
            # afterwards `output` is the only attribute (any previous ones,
            # e.g. `operation`, are dropped).
            node_data = compiled_net.nodes[obs_name]
            node_data.clear()
            node_data['output'] = obs

        del compiled_net.graph['observed']
        return compiled_net
Beispiel #3
0
    def make_observed_copy(cls, node, compiled_net, operation=None):
        """Add a renamed, observed counterpart of `node` to `compiled_net`.

        Parameters
        ----------
        node : str
            Name of the node to copy.
        compiled_net : nx.DiGraph
            Graph that receives the new node; modified in place.
        operation : callable, optional
            If given, the new node carries only this as its `operation`
            attribute. If omitted, the new node gets a shallow copy of the
            attributes `node` has in `compiled_net`.

        Returns
        -------
        str
            Name of the added node, as produced by ``observed_name(node)``.

        Raises
        ------
        ValueError
            If a node with that observed name is already present.

        """
        copy_name = observed_name(node)
        if compiled_net.has_node(copy_name):
            raise ValueError(
                "Observed node {} already exists!".format(copy_name))

        attrs = (compiled_net.nodes[node].copy() if operation is None
                 else dict(operation=operation))
        compiled_net.add_node(copy_name, **attrs)
        return copy_name
Beispiel #4
0
    def load(cls, context, compiled_net, batch_index):
        """Load the observed data from `context` into `compiled_net`.

        Each ``name -> obs`` entry of ``context.observed`` is handled as
        follows:

        * If the node ``observed_name(name)`` exists, its attributes are
          replaced so that ``output`` (holding `obs`) is its only attribute.
        * If the node does not exist, the entry is skipped.

        Parameters
        ----------
        context : ComputationContext
            Supplies the ``observed`` mapping.
        compiled_net : nx.DiGraph
            Modified in place.
        batch_index : int
            Not used by this loader.

        Returns
        -------
        net : nx.DiGraph
            The same `compiled_net`, with the observed data loaded.

        """
        for name, obs in context.observed.items():
            obs_name = observed_name(name)
            if not compiled_net.has_node(obs_name):
                continue
            # networkx 2.x does not allow item assignment on `G.nodes`, and
            # `G.node` was removed in 2.4. Mutate the attribute dict in place
            # instead. Clearing it first keeps `output` as the only attribute,
            # as the original whole-dict replacement did.
            node_data = compiled_net.nodes[obs_name]
            node_data.clear()
            node_data['output'] = obs

        return compiled_net
Beispiel #5
0
    def make_observed_copy(cls, node, compiled_net, operation=None):
        """Make a renamed copy of an observed node and add it to `compiled_net`.

        Parameters
        ----------
        node : str
            Name of the node to copy.
        compiled_net : nx.DiGraph
            Graph that receives the new node; modified in place.
        operation : callable, optional
            If given, the new node's only attribute is ``operation``.
            Otherwise the new node gets a shallow copy of the attributes
            `node` has in `compiled_net`.

        Returns
        -------
        str
            Name of the added node, ``observed_name(node)``.

        Raises
        ------
        ValueError
            If the observed node already exists in `compiled_net`.

        """
        obs_node = observed_name(node)

        if compiled_net.has_node(obs_node):
            raise ValueError(
                "Observed node {} already exists!".format(obs_node))

        if operation is None:
            # `G.node` was removed in networkx 2.4; `G.nodes[n]` returns the
            # node's attribute dict. Copy it so the two nodes don't share it.
            compiled_dict = compiled_net.nodes[node].copy()
        else:
            compiled_dict = dict(operation=operation)

        # networkx 2.x `add_node` accepts attributes only as keyword arguments;
        # passing the dict positionally raises TypeError.
        compiled_net.add_node(obs_node, **compiled_dict)
        return obs_node
 def observed(self):
     """Return the observed value associated with this node.

     Looks up the output of the node's observed counterpart, named
     ``observed_name(self.name)``, in the result of
     ``self.model.generate(0, ...)``. NOTE(review): the leading ``0`` is
     presumably a batch size of zero, so that only observed values are
     produced; confirm against ``generate``'s signature.
     """
     target = observed_name(self.name)
     return self.model.generate(0, target)[target]