Пример #1
0
    def _get_required_compnames(self):
        """Returns a set of names of components that are required by
        this Driver in order to evaluate parameters, objectives
        and constraints.  This list will include any intermediate
        components in the data flow between components referenced by
        parameters and those referenced by objectives and/or constraints.

        The set is computed on first call and cached in
        ``self._required_compnames``.  It never contains the driver's own
        name.
        """
        if self._required_compnames is None:
            # Call the base class version of get_expr_depends so we don't
            # filter out comps in our iterset; each entry is a (u, v) pair.
            conns = super(Driver, self).get_expr_depends()

            # Both ends of every dependency, excluding the driver itself.
            getcomps = {u for u, v in conns if u != self.name}
            setcomps = {v for u, v in conns if v != self.name}

            full = set(setcomps)
            full.update(getcomps)
            full.update(self.list_pseudocomps())

            compgraph = self.parent._depgraph.component_graph()

            # Add every component that find_all_connecting reports between
            # a component we set and a component we read.
            for end in getcomps:
                for start in setcomps:
                    full.update(find_all_connecting(compgraph, start, end))

            # The connecting search can pass through this driver; a driver
            # must not list itself among its own required components.
            full.discard(self.name)
            self._required_compnames = full

        return self._required_compnames
Пример #2
0
    def _get_required_compnames(self):
        """Returns a set of names of components that are required by
        this Driver in order to evaluate parameters, objectives
        and constraints.  This list will include any intermediate
        components in the data flow between components referenced by
        parameters and those referenced by objectives and/or constraints.

        The set is computed on first call and cached in
        ``self._required_compnames``.  It never contains the driver's own
        name.
        """
        if self._required_compnames is None:
            # call base class version of get_expr_depends so we don't filter out
            # comps in our iterset.  We want required names to be everything between
            # and including comps that we reference in any parameter, objective, or
            # constraint.
            conns = super(Driver, self).get_expr_depends()

            # Both ends of every dependency, excluding the driver itself.
            getcomps = {u for u, v in conns if u != self.name}
            setcomps = {v for u, v in conns if v != self.name}

            full = set(setcomps)
            full.update(getcomps)
            full.update(self.list_pseudocomps())

            compgraph = self.get_depgraph().component_graph()

            # Add every component that find_all_connecting reports between
            # a component we set and a component we read.
            for end in getcomps:
                for start in setcomps:
                    full.update(find_all_connecting(compgraph, start, end))

            # The connecting search can pass through this driver; a driver
            # must not list itself among its own required components.
            full.discard(self.name)
            self._required_compnames = full

        return self._required_compnames
Пример #3
0
    def _get_required_compnames(self):
        """Return the cached set of component names this Driver depends on.

        The set holds every component named by a parameter, objective or
        constraint, every pseudo-component of this driver, and every
        intermediate component that find_all_connecting reports between
        them.  The driver's own name is always left out.  The result is
        computed on the first call and stored in
        ``self._required_compnames``.
        """
        if self._required_compnames is not None:
            return self._required_compnames

        # Use the base-class get_expr_depends on purpose so comps in our
        # iterset are not filtered out: we want everything between and
        # including the comps referenced by parameters, objectives and
        # constraints.
        dep_pairs = super(Driver, self).get_expr_depends()
        me = self.name

        sources = {src for src, _dst in dep_pairs if src != me}
        targets = {dst for _src, dst in dep_pairs if dst != me}

        required = set(targets)
        required.update(sources, self.list_pseudocomps())

        graph = self.parent._depgraph.component_graph()
        for finish in sources:
            for origin in targets:
                required.update(find_all_connecting(graph, origin, finish))

        # Never report the driver as one of its own requirements.
        required.discard(me)
        self._required_compnames = required
        return self._required_compnames
Пример #4
0
    def _get_required_compnames(self):
        """Returns a set of names of components that are required by
        this Driver in order to evaluate parameters, objectives
        and constraints.  This list will include any intermediate
        components in the data flow between components referenced by
        parameters and those referenced by objectives and/or constraints.

        Dependencies in which either end is one of the parent's variables
        (``self.parent.list_vars()``) are ignored.  Presumably these are
        boundary variables rather than components; TODO confirm.  The set
        is computed on first call, cached in ``self._required_compnames``,
        and never contains the driver's own name.
        """
        if self._required_compnames is None:
            # Build a set once so each membership test below is O(1)
            # instead of a linear scan of the list.
            boundary_vars = set(self.parent.list_vars())

            # Base-class get_expr_depends, minus any pair touching a
            # boundary variable on either end.
            conns = [(u, v) for u, v in super(Driver, self).get_expr_depends()
                     if u not in boundary_vars and v not in boundary_vars]

            getcomps = {u for u, v in conns if u != self.name}
            setcomps = {v for u, v in conns if v != self.name}

            full = set(setcomps)
            full.update(getcomps)
            full.update(self.list_pseudocomps())

            compgraph = self.parent._depgraph.component_graph()

            # Add every component that find_all_connecting reports between
            # a component we set and a component we read.
            for end in getcomps:
                for start in setcomps:
                    full.update(find_all_connecting(compgraph, start, end))

            # The connecting search can pass through this driver; a driver
            # must not list itself among its own required components.
            full.discard(self.name)
            self._required_compnames = full

        return self._required_compnames
Пример #5
0
 def test_find_all_connecting(self):
     """Check find_all_connecting on two node pairs of self.dep's component graph.

     (A, D) is expected to yield nothing; (A, C) is expected to yield A, B and C,
     i.e. both endpoints plus B.
     """
     no_link = find_all_connecting(self.dep.component_graph(), 'A', 'D')
     self.assertEqual(no_link, set())

     a_to_c = find_all_connecting(self.dep.component_graph(), 'A', 'C')
     self.assertEqual(a_to_c, {'A', 'B', 'C'})
Пример #6
0
    def _group_nondifferentiables(self, fd=False, severed=None):
        """Method to find all non-differentiable blocks, and group them
        together, replacing them in the derivative graph with pseudo-
        assemblies that can finite difference their components together.

        fd: boolean
            set to True to finite difference the whole model together with
            fake finite difference turned off. This is mainly for checking
            your model's analytic derivatives.

        severed: list
            If a workflow has a cyclic connection, some edges must be severed.
            When a cyclic workflow calls this function, it passes a list of
            (source, target) edges; each one present in the derivative graph
            is removed before the components are grouped.

        Returns None.  When fd is not set and no non-differentiable component
        is found, it returns early without creating any pseudo-assembly.
        Otherwise each group becomes a PseudoAssembly named '~0', '~1', ...,
        which is passed to add_to_graph/clean_graph together with
        ``self.scope._depgraph`` and the derivative graph.
        """

        dgraph = self._derivative_graph

        # If we have a cyclic workflow, we need to remove severed edges from
        # the derivatives graph.  has_edge is a direct lookup rather than a
        # scan of the full list returned by edges().
        if severed is not None:
            for edge in severed:
                if dgraph.has_edge(edge[0], edge[1]):
                    dgraph.remove_edge(edge[0], edge[1])

        cgraph = dgraph.component_graph()
        comps = cgraph.nodes()

        # Existing pseudo-assemblies are top-level nodes named '~...'; the
        # components they have already absorbed must not be grouped again.
        pas = [
            dgraph.node[n]['pa_object'] for n in dgraph.nodes_iter()
            if n.startswith('~') and '.' not in n
        ]
        pa_excludes = set()
        for pa in pas:
            pa_excludes.update(pa._removed_comps)

        # Full model finite-difference, so all components go in the PA
        if fd == True:
            nondiff_groups = [comps]

        # Find the non-differentiable components
        else:

            # A component with no derivatives is non-differentiable
            nondiff = set()
            nondiff_groups = []

            for name in comps:
                if name.startswith('~') or name in pa_excludes:
                    continue  # don't want nested pseudoassemblies
                comp = self.scope.get(name)
                if not hasattr(comp, 'apply_deriv') and \
                   not hasattr(comp, 'apply_derivT') and \
                   not hasattr(comp, 'provideJ'):
                    nondiff.add(name)
                elif comp.force_fd is True:
                    nondiff.add(name)
                elif not dgraph.node[name].get('differentiable', True):
                    nondiff.add(name)

            # If a connection is non-differentiable, so are its src and
            # target components.
            for edge in dgraph.list_connections():
                src, target = edge
                data = dgraph.node[src]

                # boundary vars or fake inputs/outputs
                if src.startswith('@') or target.startswith(
                        '@') or '.' not in src:
                    continue

                # pseudoassemblies
                if src.startswith('~') or target.startswith('~'):
                    continue

                # Custom differentiable connections or ignored connections
                if data.get('data_shape'):
                    continue

                # differentiable connections
                if is_differentiable_val(self.scope.get(src)):
                    continue

                # Nothing else is differentiable: mark the component owning
                # each end of the connection ('comp.var' -> 'comp').
                nondiff.add(src.split('.')[0])
                nondiff.add(target.split('.')[0])

            # Everything is differentiable, so return
            if not nondiff:
                return

            # Groups any connected non-differentiable blocks. Each block is a
            # set of component names.
            sub = cgraph.subgraph(nondiff)
            for inodes in nx.connected_components(sub.to_undirected()):

                # Pull in any differentiable islands: whatever
                # find_all_connecting reports between two members of the
                # same block joins that block.
                nodeset = set(inodes)
                for src in inodes:
                    for targ in inodes:
                        if src != targ:
                            nodeset.update(
                                find_all_connecting(cgraph, src, targ))

                nondiff_groups.append(nodeset)

        for j, group in enumerate(nondiff_groups):
            pa_name = '~%d' % j

            # Create the pseudoassy
            pseudo = PseudoAssembly(pa_name, group, dgraph, self, fd)

            pseudo.add_to_graph(self.scope._depgraph, dgraph)
            pseudo.clean_graph(self.scope._depgraph, dgraph)
Пример #7
0
    def _group_nondifferentiables(self, fd=False, severed=None):
        """Method to find all non-differentiable blocks, and group them
        together, replacing them in the derivative graph with pseudo-
        assemblies that can finite difference their components together.

        fd: boolean
            set to True to finite difference the whole model together with
            fake finite difference turned off. This is mainly for checking
            your model's analytic derivatives.

        severed: list
            If a workflow has a cyclic connection, some edges must be severed.
            When a cyclic workflow calls this function, it passes a list of
            (source, target) edges; each one present in the derivative graph
            is removed before the components are grouped.

        Returns None.  When fd is not set and no non-differentiable component
        is found, it returns early without creating any pseudo-assembly.
        Otherwise each group becomes a PseudoAssembly named '~0', '~1', ...,
        which is passed to add_to_graph/clean_graph together with
        ``self.scope._depgraph`` and the derivative graph.
        """

        dgraph = self._derivative_graph

        # If we have a cyclic workflow, we need to remove severed edges from
        # the derivatives graph.  has_edge is a direct lookup rather than a
        # scan of the full list returned by edges().
        if severed is not None:
            for edge in severed:
                if dgraph.has_edge(edge[0], edge[1]):
                    dgraph.remove_edge(edge[0], edge[1])

        cgraph = dgraph.component_graph()
        comps = cgraph.nodes()

        # Existing pseudo-assemblies are top-level nodes named '~...'; the
        # components they have already absorbed must not be grouped again.
        pas = [dgraph.node[n]['pa_object']
               for n in dgraph.nodes_iter()
               if n.startswith('~') and '.' not in n]
        pa_excludes = set()
        for pa in pas:
            pa_excludes.update(pa._removed_comps)

        # Full model finite-difference, so all components go in the PA
        if fd == True:
            nondiff_groups = [comps]

        # Find the non-differentiable components
        else:

            # A component with no derivatives is non-differentiable
            nondiff = set()
            nondiff_groups = []

            for name in comps:
                if name.startswith('~') or name in pa_excludes:
                    continue  # don't want nested pseudoassemblies
                comp = self.scope.get(name)
                if not hasattr(comp, 'apply_deriv') and \
                   not hasattr(comp, 'apply_derivT') and \
                   not hasattr(comp, 'provideJ'):
                    nondiff.add(name)
                elif comp.force_fd is True:
                    nondiff.add(name)
                elif not dgraph.node[name].get('differentiable', True):
                    nondiff.add(name)

            # If a connection is non-differentiable, so are its src and
            # target components.
            for edge in dgraph.list_connections():
                src, target = edge
                data = dgraph.node[src]

                # boundary vars or fake inputs/outputs
                if src.startswith('@') or target.startswith('@') or '.' not in src:
                    continue

                # pseudoassemblies
                if src.startswith('~') or target.startswith('~'):
                    continue

                # Custom differentiable connections or ignored connections
                if data.get('data_shape'):
                    continue

                # differentiable connections
                if is_differentiable_val(self.scope.get(src)):
                    continue

                # Nothing else is differentiable: mark the component owning
                # each end of the connection ('comp.var' -> 'comp').
                nondiff.add(src.split('.')[0])
                nondiff.add(target.split('.')[0])

            # Everything is differentiable, so return
            if not nondiff:
                return

            # Groups any connected non-differentiable blocks. Each block is a
            # set of component names.
            sub = cgraph.subgraph(nondiff)
            for inodes in nx.connected_components(sub.to_undirected()):

                # Pull in any differentiable islands: whatever
                # find_all_connecting reports between two members of the
                # same block joins that block.
                nodeset = set(inodes)
                for src in inodes:
                    for targ in inodes:
                        if src != targ:
                            nodeset.update(find_all_connecting(cgraph, src, targ))

                nondiff_groups.append(nodeset)

        for j, group in enumerate(nondiff_groups):
            pa_name = '~%d' % j

            # Create the pseudoassy
            pseudo = PseudoAssembly(pa_name, group,
                                    dgraph, self, fd)

            pseudo.add_to_graph(self.scope._depgraph, dgraph)
            pseudo.clean_graph(self.scope._depgraph, dgraph)