def _get_required_compnames(self):
    """Returns a set of names of components that are required by this
    Driver in order to evaluate parameters, objectives and constraints.

    The set includes the components referenced by parameters and by
    objectives and/or constraints, any intermediate components in the
    data flow between those two groups, and this Driver's
    pseudo-components. The Driver's own name is never included.

    The result is computed on the first call and cached in
    ``self._required_compnames``; later calls return the cached set.
    """
    if self._required_compnames is None:
        # super(Driver, self) resolves get_expr_depends past Driver in the
        # MRO, i.e. the base-class version is used rather than Driver's.
        conns = super(Driver, self).get_expr_depends()
        # First / second element of each dependency pair, excluding this
        # Driver itself.
        getcomps = {u for u, v in conns if u != self.name}
        setcomps = {v for u, v in conns if v != self.name}

        full = set(setcomps)
        full.update(getcomps)
        full.update(self.list_pseudocomps())

        compgraph = self.parent._depgraph.component_graph()
        # Pull in every component lying on a path from a set-side
        # component to a get-side component.
        for end in getcomps:
            for start in setcomps:
                full.update(find_all_connecting(compgraph, start, end))

        # The filters above keep the Driver out of getcomps/setcomps, but a
        # connecting path found above may pass through the Driver's own
        # node; make sure it never appears among its own required comps.
        full.discard(self.name)

        self._required_compnames = full
    return self._required_compnames
def _get_required_compnames(self):
    """Returns a set of names of components that are required by this
    Driver in order to evaluate parameters, objectives and constraints.

    The set includes the components referenced by parameters and by
    objectives and/or constraints, any intermediate components in the
    data flow between those two groups, and this Driver's
    pseudo-components. The Driver's own name is never included.

    The result is computed on the first call and cached in
    ``self._required_compnames``; later calls return the cached set.
    """
    if self._required_compnames is None:
        # call base class version of get_expr_depends so we don't filter out
        # comps in our iterset. We want required names to be everything between
        # and including comps that we reference in any parameter, objective, or
        # constraint.
        conns = super(Driver, self).get_expr_depends()
        getcomps = {u for u, v in conns if u != self.name}
        setcomps = {v for u, v in conns if v != self.name}

        full = set(setcomps)
        full.update(getcomps)
        full.update(self.list_pseudocomps())

        compgraph = self.get_depgraph().component_graph()
        # Pull in every component lying on a path from a set-side
        # component to a get-side component.
        for end in getcomps:
            for start in setcomps:
                full.update(find_all_connecting(compgraph, start, end))

        # The filters above keep the Driver out of getcomps/setcomps, but a
        # connecting path found above may pass through the Driver's own
        # node; make sure it never appears among its own required comps.
        full.discard(self.name)

        self._required_compnames = full
    return self._required_compnames
def _get_required_compnames(self):
    """Return the names of the components this Driver depends on.

    The returned set holds every component referenced by one of this
    Driver's parameters, objectives or constraints, every intermediate
    component on a data-flow path between the parameter-side components
    and the objective/constraint-side components, and this Driver's
    pseudo-components. The Driver's own name is never included.

    The set is built on the first call and cached in
    ``self._required_compnames``; later calls return the cached value.
    """
    if self._required_compnames is not None:
        return self._required_compnames

    # Use the base-class get_expr_depends (skipping Driver's override) so
    # components in our iteration set are not filtered out: the required
    # names must span everything between, and including, the components
    # referenced by any parameter, objective or constraint.
    depends = super(Driver, self).get_expr_depends()
    get_side = {src for src, _dst in depends if src != self.name}
    set_side = {dst for _src, dst in depends if dst != self.name}

    required = set_side | get_side
    required.update(self.list_pseudocomps())

    graph = self.parent._depgraph.component_graph()
    for end in get_side:
        for start in set_side:
            required.update(find_all_connecting(graph, start, end))

    # A connecting path may run through this Driver's own node.
    required.discard(self.name)

    self._required_compnames = required
    return self._required_compnames
def _get_required_compnames(self):
    """Returns a set of names of components that are required by this
    Driver in order to evaluate parameters, objectives and constraints.

    The set includes the components referenced by parameters and by
    objectives and/or constraints, any intermediate components in the
    data flow between those two groups, and this Driver's
    pseudo-components. Dependency pairs that mention any name returned by
    ``self.parent.list_vars()`` are ignored, and the Driver's own name is
    never included.

    The result is computed on the first call and cached in
    ``self._required_compnames``; later calls return the cached set.
    """
    if self._required_compnames is None:
        # Built once as a set: it is probed twice per dependency pair below.
        boundary_vars = set(self.parent.list_vars())
        conns = super(Driver, self).get_expr_depends()
        getcomps = {u for u, v in conns
                    if u != self.name
                    and u not in boundary_vars and v not in boundary_vars}
        setcomps = {v for u, v in conns
                    if v != self.name
                    and u not in boundary_vars and v not in boundary_vars}

        full = set(setcomps)
        full.update(getcomps)
        full.update(self.list_pseudocomps())

        compgraph = self.parent._depgraph.component_graph()
        # Pull in every component lying on a path from a set-side
        # component to a get-side component.
        for end in getcomps:
            for start in setcomps:
                full.update(find_all_connecting(compgraph, start, end))

        # The filters above keep the Driver out of getcomps/setcomps, but a
        # connecting path found above may pass through the Driver's own
        # node; make sure it never appears among its own required comps.
        full.discard(self.name)

        self._required_compnames = full
    return self._required_compnames
def test_find_all_connecting(self):
    """find_all_connecting returns every component on a path start -> end.

    The fixture graph is expected to have no path from 'A' to 'D' (empty
    result), while the path from 'A' to 'C' goes through 'B'.
    """
    cases = [
        (('A', 'D'), set()),
        (('A', 'C'), {'A', 'B', 'C'}),
    ]
    for (start, end), expected in cases:
        self.assertEqual(
            find_all_connecting(self.dep.component_graph(), start, end),
            expected)
def _group_nondifferentiables(self, fd=False, severed=None):
    """Method to find all non-differentiable blocks, and group them
    together, replacing them in the derivative graph with pseudo-
    assemblies that can finite difference their components together.

    fd: boolean
        set to True to finite difference the whole model together with
        fake finite difference turned off. This is mainly for checking
        your model's analytic derivatives.

    severed: list
        If a workflow has a cyclic connection, some edges must be severed.
        When a cyclic workflow calls this function, it passes a list of
        edges so that they can be severed prior to the topological sort.

    Returns None. When fd is false and no non-differentiable component or
    connection is found, returns early without creating any
    pseudo-assembly.
    """
    dgraph = self._derivative_graph

    # If we have a cyclic workflow, we need to remove severed edges from
    # the derivatives graph. has_edge is a direct adjacency lookup, unlike
    # testing membership in the full edge list for every severed edge.
    if severed is not None:
        for edge in severed:
            src, target = edge[0], edge[1]
            if dgraph.has_edge(src, target):
                dgraph.remove_edge(src, target)

    cgraph = dgraph.component_graph()
    comps = cgraph.nodes()

    # Components already swallowed by existing pseudo-assemblies (nodes
    # named '~...' without a '.') must not be grouped again.
    pa_excludes = set()
    for node in dgraph.nodes_iter():
        if node.startswith('~') and '.' not in node:
            pa_excludes.update(dgraph.node[node]['pa_object']._removed_comps)

    if fd:
        # Full model finite-difference, so all components go in the PA.
        nondiff_groups = [comps]
    else:
        nondiff = set()
        nondiff_groups = []

        # A component is non-differentiable if it provides none of the
        # derivative methods, is forced to finite difference, or its node
        # is flagged non-differentiable in the derivative graph.
        for name in comps:
            if name.startswith('~') or name in pa_excludes:
                continue  # don't want nested pseudoassemblies
            comp = self.scope.get(name)
            if not (hasattr(comp, 'apply_deriv') or
                    hasattr(comp, 'apply_derivT') or
                    hasattr(comp, 'provideJ')):
                nondiff.add(name)
            elif comp.force_fd is True:
                nondiff.add(name)
            elif not dgraph.node[name].get('differentiable', True):
                nondiff.add(name)

        # If a connection is non-differentiable, so are its src and
        # target components.
        for src, target in dgraph.list_connections():
            # boundary vars or fake inputs/outputs
            if src.startswith('@') or target.startswith('@') \
                    or '.' not in src:
                continue
            # pseudoassemblies
            if src.startswith('~') or target.startswith('~'):
                continue
            # Custom differentiable connections or ignored connections
            if dgraph.node[src].get('data_shape'):
                continue
            # differentiable connections
            if is_differentiable_val(self.scope.get(src)):
                continue
            # Nothing else is differentiable.
            nondiff.add(src.split('.')[0])
            nondiff.add(target.split('.')[0])

        # Everything is differentiable, so return.
        if not nondiff:
            return

        # Group any connected non-differentiable blocks. Each block is a
        # set of component names.
        sub = cgraph.subgraph(nondiff)
        for inodes in nx.connected_components(sub.to_undirected()):
            # Pull in any differentiable islands lying on a path between
            # two members of the block.
            nodeset = set(inodes)
            for start in inodes:
                for end in inodes:
                    if start != end:
                        nodeset.update(find_all_connecting(cgraph, start, end))
            nondiff_groups.append(nodeset)

    # NOTE(review): names restart at '~0' on every call; confirm they cannot
    # clash with pseudo-assemblies already present in dgraph (the '~' nodes
    # scanned for pa_excludes above).
    for j, group in enumerate(nondiff_groups):
        pa_name = '~%d' % j
        # Create the pseudoassy
        pseudo = PseudoAssembly(pa_name, group, dgraph, self, fd)
        pseudo.add_to_graph(self.scope._depgraph, dgraph)
        pseudo.clean_graph(self.scope._depgraph, dgraph)
def _group_nondifferentiables(self, fd=False, severed=None):
    """Method to find all non-differentiable blocks, and group them
    together, replacing them in the derivative graph with pseudo-
    assemblies that can finite difference their components together.

    fd: boolean
        set to True to finite difference the whole model together with
        fake finite difference turned off. This is mainly for checking
        your model's analytic derivatives.

    severed: list
        If a workflow has a cyclic connection, some edges must be severed.
        When a cyclic workflow calls this function, it passes a list of
        edges so that they can be severed prior to the topological sort.

    Returns None. When fd is false and no non-differentiable component or
    connection is found, returns early without creating any
    pseudo-assembly.
    """
    dgraph = self._derivative_graph

    # If we have a cyclic workflow, we need to remove severed edges from
    # the derivatives graph. has_edge is a direct adjacency lookup, unlike
    # testing membership in the full edge list for every severed edge.
    if severed is not None:
        for edge in severed:
            src, target = edge[0], edge[1]
            if dgraph.has_edge(src, target):
                dgraph.remove_edge(src, target)

    cgraph = dgraph.component_graph()
    comps = cgraph.nodes()

    # Components already swallowed by existing pseudo-assemblies (nodes
    # named '~...' without a '.') must not be grouped again.
    pa_excludes = set()
    for node in dgraph.nodes_iter():
        if node.startswith('~') and '.' not in node:
            pa_excludes.update(dgraph.node[node]['pa_object']._removed_comps)

    if fd:
        # Full model finite-difference, so all components go in the PA.
        nondiff_groups = [comps]
    else:
        nondiff = set()
        nondiff_groups = []

        # A component is non-differentiable if it provides none of the
        # derivative methods, is forced to finite difference, or its node
        # is flagged non-differentiable in the derivative graph.
        for name in comps:
            if name.startswith('~') or name in pa_excludes:
                continue  # don't want nested pseudoassemblies
            comp = self.scope.get(name)
            if not (hasattr(comp, 'apply_deriv') or
                    hasattr(comp, 'apply_derivT') or
                    hasattr(comp, 'provideJ')):
                nondiff.add(name)
            elif comp.force_fd is True:
                nondiff.add(name)
            elif not dgraph.node[name].get('differentiable', True):
                nondiff.add(name)

        # If a connection is non-differentiable, so are its src and
        # target components.
        for src, target in dgraph.list_connections():
            # boundary vars or fake inputs/outputs
            if src.startswith('@') or target.startswith('@') \
                    or '.' not in src:
                continue
            # pseudoassemblies
            if src.startswith('~') or target.startswith('~'):
                continue
            # Custom differentiable connections or ignored connections
            if dgraph.node[src].get('data_shape'):
                continue
            # differentiable connections
            if is_differentiable_val(self.scope.get(src)):
                continue
            # Nothing else is differentiable.
            nondiff.add(src.split('.')[0])
            nondiff.add(target.split('.')[0])

        # Everything is differentiable, so return.
        if not nondiff:
            return

        # Group any connected non-differentiable blocks. Each block is a
        # set of component names.
        sub = cgraph.subgraph(nondiff)
        for inodes in nx.connected_components(sub.to_undirected()):
            # Pull in any differentiable islands lying on a path between
            # two members of the block.
            nodeset = set(inodes)
            for start in inodes:
                for end in inodes:
                    if start != end:
                        nodeset.update(find_all_connecting(cgraph, start, end))
            nondiff_groups.append(nodeset)

    # NOTE(review): names restart at '~0' on every call; confirm they cannot
    # clash with pseudo-assemblies already present in dgraph (the '~' nodes
    # scanned for pa_excludes above).
    for j, group in enumerate(nondiff_groups):
        pa_name = '~%d' % j
        # Create the pseudoassy
        pseudo = PseudoAssembly(pa_name, group, dgraph, self, fd)
        pseudo.add_to_graph(self.scope._depgraph, dgraph)
        pseudo.clean_graph(self.scope._depgraph, dgraph)