Пример #1
0
    def restricted(self, inputs=None, outputs=None):
        """ Return an ExecutionModel which has been restricted to the given
        inputs and outputs.

        The selected statements are those that are reachable both from the
        statements touching a changed input (walking the reversed dependency
        graph) and from the statements producing a desired output (walking
        the dependency graph). They are kept in the same relative order they
        have in `self.statements`.

        Parameters
        ----------
        inputs : sequence of str, optional
            The names of the input variables which have changed. If not
            provided, then every possible input will be considered to have
            changed.
        outputs : sequence of str, optional
            The names of the output variables which are desired to recompute.
            If not provided, then every statement is treated as producing a
            desired output.

        Returns
        -------
        model : ExecutionModel
            A new instance of ``self.__class__`` holding the selected
            statements.
        """
        if inputs is not None:
            inputs = set(inputs)
        if outputs is not None:
            outputs = set(outputs)

        # Convert the names to the statements whose input/output bindings
        # mention them directly.
        input_nodes = set()
        output_nodes = set()
        for statement in self.statements:
            inbindings = set(iv.binding for iv in statement.inputs)
            if inputs is None or inputs.intersection(inbindings):
                input_nodes.add(statement)
            outbindings = set(ov.binding for ov in statement.outputs)
            if outputs is None or outputs.intersection(outbindings):
                output_nodes.add(statement)

        # Find the reachable subgraphs of both the inputs and outputs.
        # NOTE(review): presumably dep_graph maps a statement to the
        # statements it depends on, so the first walk goes "downstream" of
        # the inputs and the second "upstream" of the outputs -- confirm.
        dep_graph = self.dep_graph
        inreachable = graph.reverse(graph.reachable_graph(
            graph.reverse(dep_graph), input_nodes))
        outreachable = graph.reachable_graph(dep_graph, output_nodes)

        # Our desired set of statements is the intersection of these two
        # graphs. Both keys and adjacency entries count as members of a
        # subgraph. items() (not the Python-2-only iteritems()) keeps this
        # working on Python 2 and 3.
        in_nodes = set()
        for node, deps in inreachable.items():
            in_nodes.add(node)
            in_nodes.update(deps)
        intersection = set()
        for node, deps in outreachable.items():
            if node in in_nodes:
                intersection.add(node)
                intersection.update(d for d in deps if d in in_nodes)

        # Iterating a set gives an arbitrary, run-dependent order; emit the
        # selected statements in their original order instead (each once).
        ordered = []
        for statement in self.statements:
            if statement in intersection:
                ordered.append(statement)
                intersection.discard(statement)
        # Anything not found in self.statements is still kept, so the
        # selected set is exactly the same as before.
        ordered.extend(intersection)

        em = self.__class__(
            statements=ordered,
        )
        return em
Пример #2
0
 def _base(self, graph, nodes, result, error=None):
     """ Check ``G.reachable_graph(graph, nodes)`` against an expectation.

     With a falsy `error`, assert that the computed reachable subgraph is
     equal to `result`. Otherwise assert that performing that same check
     raises `error`.
     """
     if not error:
         self.assertEqual(G.reachable_graph(graph, nodes), result)
     else:
         # Re-enter without `error` so the equality check runs inside
         # assertRaises; the expected exception has to escape from it.
         self.assertRaises(error, self._base, graph, nodes, result)