def test_connected_expr(self):
    """Source expressions accept variable indices and calls; destination
    expressions must be assignable and may use only constant indices."""

    def parse(text, is_dest=False):
        # Reproduce the two call forms used by this test: the destination
        # form passes is_dest=True, the source form omits the keyword.
        if is_dest:
            ConnectedExprEvaluator(text, self.top, is_dest=True)._parse()
        else:
            ConnectedExprEvaluator(text, self.top)._parse()

    def expect_dest_failure(text, expected):
        # Parsing *text* as a destination must raise with exactly *expected*.
        try:
            parse(text, is_dest=True)
        except Exception as err:
            self.assertEqual(str(err), expected)
        else:
            self.fail("Exception expected")

    # Non-constant index: accepted as a source, rejected as a destination.
    parse("var1[x]")
    expect_dest_failure("var1[x]",
                        "bad destination expression 'var1[x]':"
                        " only constant indices are allowed for arrays and slices")

    # Call expression: accepted as a source, not assignable as a destination.
    parse("var1(2.3)")
    expect_dest_failure("var1(2.3)",
                        "bad destination expression 'var1(2.3)':"
                        " not assignable")

    # Constant slice: accepted on both sides.
    parse("var1[1:5:2]")
    parse("var1[1:5:2]", is_dest=True)

    # Slice with a variable bound: accepted as a source only.
    parse("var1[1:x:2]")
    expect_dest_failure("var1[1:x:2]",
                        "bad destination expression"
                        " 'var1[1:x:2]': only constant indices are allowed"
                        " for arrays and slices")
def check_connect(self, src, dest, scope):
    """Check validity of connecting a source expression to a destination
    expression.

    src: string
        Source expression text.

    dest: string
        Destination expression text.

    scope: object
        Scope in which both expressions are evaluated; also used to report
        the "already connected" error via ``scope.raise_exception``.

    Returns ``(srcexpr, destexpr)``, both ConnectedExprEvaluator objects
    built with ``getter='get_wrapped_attr'`` (destexpr with ``is_dest=True``).

    Raises RuntimeError if *dest* already has a source, or if *src* and
    *dest* refer to the same component.
    """
    existing = self.get_source(dest)
    if existing is not None:
        scope.raise_exception("'%s' is already connected to source '%s'" %
                              (dest, existing), RuntimeError)

    destexpr = ConnectedExprEvaluator(dest, scope, getter='get_wrapped_attr',
                                      is_dest=True)
    srcexpr = ConnectedExprEvaluator(src, scope, getter='get_wrapped_attr')

    srccomps = srcexpr.get_referenced_compnames()
    # Copy into a list rather than calling .pop(), so the collection returned
    # by get_referenced_compnames() is never mutated.
    destcomps = list(destexpr.get_referenced_compnames())
    if destcomps and destcomps[0] in srccomps:
        raise RuntimeError("'%s' and '%s' refer to the same component." %
                           (src, dest))
    return srcexpr, destexpr
def __init__(self, *args, **kwargs):
    """Build the objective expression and the PseudoComponent that backs it.

    All positional and keyword arguments are forwarded unchanged to the
    parent class constructor.  If the expression references variables that
    cannot be resolved, the exception produced by
    ``ConnectedExprEvaluator._invalid_expression_error`` is raised.
    """
    super(Objective, self).__init__(*args, **kwargs)
    # Initialize the slot before validation.
    self._pseudo = None

    missing = self.get_unresolved()
    if missing:
        template = "Can't add objective '{0}' because of invalid variables {1}"
        raise ConnectedExprEvaluator._invalid_expression_error(
            missing, self.text, template)

    pcomp = PseudoComponent(self.scope, self, pseudo_type='objective')
    self._pseudo = pcomp
    self.pcomp_name = pcomp.name
def check_connect(self, src, dest, scope):
    """Validate a connection from expression *src* to expression *dest* and
    decide whether pseudocomp links are required.

    Returns ``(srcexpr, destexpr, needs_pseudo)``: the two
    ConnectedExprEvaluator objects built in *scope*, plus the result of
    ``self._needs_pseudo(scope, srcexpr, destexpr)``.

    Raises RuntimeError (reported via ``scope.raise_exception``) when *dest*
    already has a source, and RuntimeError when both ends refer to the same
    component.
    """
    if self.get_source(dest) is not None:
        msg = "'%s' is already connected to source '%s'" % (
            dest, self.get_source(dest))
        scope.raise_exception(msg, RuntimeError)

    destexpr = ConnectedExprEvaluator(dest, scope, is_dest=True)
    srcexpr = ConnectedExprEvaluator(src, scope, getter='get_attr')

    src_comps = srcexpr.get_referenced_compnames()
    dest_comps = list(destexpr.get_referenced_compnames())
    same_comp = bool(dest_comps) and dest_comps[0] in src_comps
    if same_comp:
        raise RuntimeError("'%s' and '%s' refer to the same component." % (src, dest))

    needs_pseudo = self._needs_pseudo(scope, srcexpr, destexpr)
    return srcexpr, destexpr, needs_pseudo
def add_objective(self, expr, name=None, scope=None):
    """Register an objective expression with this driver.

    expr: string
        The objective expression. Spaces are stripped (``_remove_spaces``)
        before it is compared against existing objectives or stored.

    name: string (optional)
        Key under which the objective is stored. Defaults to the
        space-stripped expression.

    scope: object (optional)
        Object used as the evaluation scope, resolved via ``self._get_scope``.

    Errors are reported through ``self.parent.raise_exception``:
    RuntimeError when the objective limit is reached, AttributeError for a
    duplicate expression or name, and the type produced by
    ``ConnectedExprEvaluator._invalid_expression_error`` for unresolved
    variables.
    """
    limit = self._max_objectives
    # A non-positive limit means "unlimited"; len() is only consulted when
    # a limit is actually in effect.
    at_capacity = limit > 0 and len(self._objectives) >= limit
    if at_capacity:
        self.parent.raise_exception(
            "Can't add objective '%s'. Only %d"
            " objectives are allowed" % (expr, limit),
            RuntimeError)

    expr = _remove_spaces(expr)
    if expr in self._objectives:
        self.parent.raise_exception(
            "Trying to add objective '%s' to"
            " driver, but it's already there" % expr,
            AttributeError)
    if name is not None and name in self._objectives:
        self.parent.raise_exception(
            "Trying to add objective '%s' to"
            " driver using name '%s', but name is"
            " already used" % (expr, name),
            AttributeError)

    objective = Objective(expr, self._get_scope(scope))

    missing = objective.get_unresolved()
    if missing:
        template = "Can't add objective '{0}' because of invalid variables {1}"
        error = ConnectedExprEvaluator._invalid_expression_error(
            missing, objective.text, template)
        self.parent.raise_exception(str(error), type(error))

    key = expr if name is None else name
    objective.activate()
    self._objectives[key] = objective
    self.parent.config_changed()
def add_objective(self, expr, name=None, scope=None):
    """Add an objective to the driver.

    expr: string
        Objective expression; spaces are removed before use.

    name: string (optional)
        Alternate key for the objective; the cleaned expression is used
        when omitted.

    scope: object (optional)
        Scope for evaluating the expression (passed to ``self._get_scope``).

    All validation failures are routed through
    ``self.parent.raise_exception`` (RuntimeError for too many objectives,
    AttributeError for duplicates, and the error type built by
    ``ConnectedExprEvaluator._invalid_expression_error`` for unresolved
    variables).
    """
    if self._max_objectives > 0:
        if len(self._objectives) >= self._max_objectives:
            too_many = ("Can't add objective '%s'. Only %d"
                        " objectives are allowed" % (expr, self._max_objectives))
            self.parent.raise_exception(too_many, RuntimeError)

    expr = _remove_spaces(expr)

    if expr in self._objectives:
        dup_expr = ("Trying to add objective '%s' to"
                    " driver, but it's already there" % expr)
        self.parent.raise_exception(dup_expr, AttributeError)

    if name is not None:
        if name in self._objectives:
            dup_name = ("Trying to add objective '%s' to"
                        " driver using name '%s', but name is"
                        " already used" % (expr, name))
            self.parent.raise_exception(dup_name, AttributeError)

    scope = self._get_scope(scope)
    obj = Objective(expr, scope)

    bad_vars = obj.get_unresolved()
    if bad_vars:
        fmt = "Can't add objective '{0}' because of invalid variables {1}"
        exc = ConnectedExprEvaluator._invalid_expression_error(bad_vars, obj.text, fmt)
        self.parent.raise_exception(str(exc), type(exc))

    if name is None:
        name = expr
    obj.activate()
    self._objectives[name] = obj
    self.parent.config_changed()
def add_response(self, expr, name=None, scope=None):
    """Register a response expression with this driver.

    expr: string
        The response expression; spaces are stripped (``_remove_spaces``)
        before it is compared or stored.

    name: string (optional)
        Key for the response; defaults to the cleaned expression.

    scope: object (optional)
        Evaluation scope, resolved via ``self._get_scope``.

    Errors go through ``self.parent.raise_exception``: AttributeError for a
    duplicate expression or name, and the type built by
    ``ConnectedExprEvaluator._invalid_expression_error`` when the expression
    cannot be resolved.
    """
    expr = _remove_spaces(expr)

    if expr in self._responses:
        self.parent.raise_exception(
            "Trying to add response '%s' to"
            " driver, but it's already there" % expr,
            AttributeError)
    if name is not None and name in self._responses:
        self.parent.raise_exception(
            "Trying to add response '%s' to"
            " driver using name '%s', but name is"
            " already used" % (expr, name),
            AttributeError)

    scope = self._get_scope(scope)
    try:
        response = Response(expr, scope)
        missing = response.get_unresolved()
    except AttributeError:
        # An AttributeError while building or inspecting the response is
        # reported as if the whole expression were unresolvable.
        missing = [expr]

    if missing:
        template = "Can't add response '{0}' because of invalid variables {1}"
        error = ConnectedExprEvaluator._invalid_expression_error(
            missing, expr, template)
        self.parent.raise_exception(str(error), type(error))

    if name is None:
        name = expr
    # NOTE: unlike add_objective, no activate() call is made for responses.
    self._responses[name] = response
    self.parent.config_changed()
def add_response(self, expr, name=None, scope=None):
    """Add a response to the driver.

    expr: string
        Response expression; spaces are removed before use.

    name: string (optional)
        Alternate key for the response; the cleaned expression is used
        when omitted.

    scope: object (optional)
        Scope for evaluating the expression (passed to ``self._get_scope``).

    Duplicate expressions/names and unresolved variables are reported via
    ``self.parent.raise_exception``.
    """
    expr = _remove_spaces(expr)
    responses = self._responses

    if expr in responses:
        dup_expr = ("Trying to add response '%s' to"
                    " driver, but it's already there" % expr)
        self.parent.raise_exception(dup_expr, AttributeError)

    if name is not None:
        if name in responses:
            dup_name = ("Trying to add response '%s' to"
                        " driver using name '%s', but name is"
                        " already used" % (expr, name))
            self.parent.raise_exception(dup_name, AttributeError)

    resolved_scope = self._get_scope(scope)
    try:
        resp = Response(expr, resolved_scope)
        bad_vars = resp.get_unresolved()
    except AttributeError:
        bad_vars = [expr]

    if bad_vars:
        fmt = "Can't add response '{0}' because of invalid variables {1}"
        exc = ConnectedExprEvaluator._invalid_expression_error(bad_vars, expr, fmt)
        self.parent.raise_exception(str(exc), type(exc))

    key = expr if name is None else name
    # Responses are stored without being activated.
    self._responses[key] = resp
    self.parent.config_changed()
def check_connect(self, src, dest, scope): """Check validity of connecting a source expression to a destination expression, and determine if we need to create links to pseudocomps. """ if self.get_source(dest) is not None: scope.raise_exception( "'%s' is already connected to source '%s'" % (dest, self.get_source(dest)), RuntimeError) destexpr = ConnectedExprEvaluator(dest, scope, is_dest=True) srcexpr = ConnectedExprEvaluator(src, scope, getter='get_attr_w_copy') srccomps = srcexpr.get_referenced_compnames() destcomps = list(destexpr.get_referenced_compnames()) if destcomps and destcomps[0] in srccomps: raise RuntimeError("'%s' and '%s' refer to the same component." % (src, dest)) try: return srcexpr, destexpr, self._needs_pseudo(srcexpr, destexpr) except AttributeError as err: exc_type, value, traceback = sys.exc_info() invalid_vars = srcexpr.get_unresolved() + destexpr.get_unresolved() parts = invalid_vars[0].rsplit('.', 1) parent = repr(scope.name) if scope.name else 'top level assembly' vname = repr(parts[0]) if len(parts) > 1: parent = repr(parts[0]) vname = repr(parts[1]) msg = "{parent} has no variable {vname}" msg = msg.format(parent=parent, vname=vname) raise AttributeError, AttributeError(msg), traceback
def check_connect(self, src, dest, scope):
    """Check validity of connecting a source expression to a destination
    expression.

    Both expressions are built as ConnectedExprEvaluator objects in *scope*
    with ``getter='get_wrapped_attr'``; the destination additionally uses
    ``is_dest=True``.  Returns ``(srcexpr, destexpr)``.

    Raises RuntimeError (via ``scope.raise_exception``) if *dest* is already
    connected, and RuntimeError if *src* and *dest* refer to the same
    component.
    """
    existing = self.get_source(dest)
    if existing is not None:
        scope.raise_exception(
            "'%s' is already connected to source '%s'" % (dest, existing),
            RuntimeError)

    destexpr = ConnectedExprEvaluator(dest, scope, getter='get_wrapped_attr',
                                      is_dest=True)
    srcexpr = ConnectedExprEvaluator(src, scope, getter='get_wrapped_attr')

    srccomps = srcexpr.get_referenced_compnames()
    # Work on a copy: the previous destcomps.pop() mutated whatever
    # collection get_referenced_compnames() handed back.
    destcomps = list(destexpr.get_referenced_compnames())
    if destcomps and destcomps[0] in srccomps:
        raise RuntimeError("'%s' and '%s' refer to the same component." %
                           (src, dest))
    return srcexpr, destexpr
def check_connect(self, src, dest, scope):
    """Verify that *src* may drive *dest* and report whether pseudocomp
    links are needed.

    The destination is parsed with ``is_dest=True`` and the source with
    ``getter='get_attr'``.  The returned triple is
    ``(srcexpr, destexpr, self._needs_pseudo(scope, srcexpr, destexpr))``.

    Raises RuntimeError (through ``scope.raise_exception``) if *dest*
    already has a source, and RuntimeError if both sides reference the same
    component.
    """
    if self.get_source(dest) is not None:
        scope.raise_exception("'%s' is already connected to source '%s'"
                              % (dest, self.get_source(dest)),
                              RuntimeError)

    dest_eval = ConnectedExprEvaluator(dest, scope, is_dest=True)
    src_eval = ConnectedExprEvaluator(src, scope, getter='get_attr')

    src_names = src_eval.get_referenced_compnames()
    dest_names = list(dest_eval.get_referenced_compnames())
    if dest_names:
        if dest_names[0] in src_names:
            raise RuntimeError("'%s' and '%s' refer to the same component." % (src, dest))

    return src_eval, dest_eval, self._needs_pseudo(scope, src_eval, dest_eval)
class PseudoComponent(object):
    """A 'fake' component that is constructed from an ExprEvaluator.
    This fake component can be added to a dependency graph and executed
    along with 'real' components.

    Each variable referenced by the source expression becomes an input
    named ``in0``, ``in1``, ... and the single output is ``out0``.
    """

    implements(IComponent, IPseudoComp)

    def __init__(self, parent, srcexpr, destexpr=None, translate=True,
                 pseudo_type=None, subtype=None, exprobject=None):
        """Build the pseudo-component from *srcexpr* and optional *destexpr*.

        parent: object
            Owning scope; stored as a weakref by the ``parent`` property.

        srcexpr: expression object
            Source expression.  Each reference from
            ``srcexpr.ordered_refs()`` becomes an input ``inN``
            (initialized to 0.).

        destexpr: expression object (optional)
            Destination expression; may reference at most one variable,
            which is mapped to ``out0``.  Defaults to ``DummyExpr()``.

        translate: bool
            If True, rewrite the source text so references use the ``inN``
            names (``transform_expression``).

        pseudo_type, subtype, exprobject:
            Bookkeeping values stored on the instance (see comments below).

        Raises RuntimeError if *destexpr* references more than one variable,
        and TypeError if source and destination units are incompatible.
        """
        if destexpr is None:
            destexpr = DummyExpr()
        self._parent = None
        self.parent = parent
        self.name = _get_new_name(parent)
        self._inmap = {}  # mapping of component vars to our inputs
        self._meta = {}
        self._inputs = []

        # Flags and caching used by the derivatives calculation
        self.force_fd = False
        self._provideJ_bounds = None

        # a string indicating the type of pseudocomp this is, e.g., 'units',
        # 'constraint', 'objective', or 'multi_var_expr'
        self._pseudo_type = pseudo_type
        self._subtype = subtype  # for constraints, 'equality' or 'inequality'
        # object responsible for creation of this pcomp, e.g. a Constraint
        self._exprobj = exprobject

        self._orig_src = srcexpr.text
        self._orig_dest = destexpr.text
        self.Jsize = None
        self.mpi = MPI_info()

        # varmap: reference text -> local name ('inN' or 'out0')
        # rvarmap: base variable name -> set of reference texts using it
        varmap = {}
        rvarmap = {}
        for i, ref in enumerate(srcexpr.ordered_refs()):
            in_name = 'in%d' % i
            self._inputs.append(in_name)
            self._inmap[ref] = in_name
            varmap[ref] = in_name
            rvarmap.setdefault(_get_varname(ref), set()).add(ref)
            setattr(self, in_name, 0.)

        refs = list(destexpr.refs())
        if refs:
            if len(refs) == 1:
                setattr(self, 'out0', None)
            else:
                raise RuntimeError("output of PseudoComponent must reference"
                                   " only one variable")
            varmap[refs[0]] = 'out0'
            rvarmap.setdefault(_get_varname(refs[0]), set()).add(refs[0])

        noflat = False
        for name, meta in srcexpr.get_metadata():
            # If any input is noflat, then the output must be too.
            if 'noflat' in meta:
                noflat = True
            for rname in rvarmap[name]:
                self._meta[varmap[rname]] = meta

        for name, meta in destexpr.get_metadata():
            # NOTE(review): this writes into the dict yielded by
            # destexpr.get_metadata(); if that dict is shared (e.g. variable
            # metadata) the change is visible elsewhere -- confirm intent.
            if noflat and 'noflat' not in meta:
                meta['noflat'] = True
            for rname in rvarmap[name]:
                self._meta[varmap[rname]] = meta

        if translate:
            xformed_src = transform_expression(srcexpr.text, self._inmap)
        else:
            xformed_src = srcexpr.text

        # NOTE(review): raises KeyError if no metadata was recorded for
        # 'out0' (e.g. destexpr with no refs) unless DummyExpr/get_metadata
        # supply it elsewhere -- confirm.
        out_units = self._meta['out0'].get('units')
        pq = None
        if out_units is not None:
            # evaluate the src expression using UnitsOnlyPQ objects

            tmpdict = {}

            # First, replace values with UnitsOnlyPQ objects
            for inp in self._inputs:
                units = self._meta[inp].get('units')
                if units:
                    tmpdict[inp] = UnitsOnlyPQ(0., units)
                else:
                    tmpdict[inp] = 0.

            # Evaluating the translated source on unit-only placeholders
            # yields the unit of the source expression.
            pq = eval(xformed_src, _expr_dict, tmpdict)
            self._srcunits = pq.unit

            # Rewrite the source AST so its result is converted to out_units.
            unitnode = ast.parse(xformed_src)
            try:
                unitxform = unit_xform(unitnode, self._srcunits, out_units)
            except Exception as err:
                raise TypeError("Incompatible units for '%s' and '%s': %s" % (srcexpr.text,
                                                                              destexpr.text, err))
            unit_src = print_node(unitxform)
            xformed_src = unit_src
        else:
            self._srcunits = None

        # The evaluator is scoped to this object, so names like 'in0' are
        # looked up on the pseudo-component itself.
        self._srcexpr = ConnectedExprEvaluator(xformed_src, scope=self)

        # this is just the equation string (for debugging)
        if self._orig_dest:
            self._outdests = [self._orig_dest]
            if pq is None:
                sunit = dunit = ''
            else:
                sunit = "'%s'" % pq.get_unit_name()
                dunit = "'%s'" % out_units
            self._orig_expr = "%s %s -> %s %s" % (self._orig_src, sunit,
                                                  self._orig_dest, dunit)
        else:
            self._outdests = []
            self._orig_expr = self._orig_src

        self.missing_deriv_policy = 'error'
        # When True, run/evaluate/provideJ negate their results.
        self._negate = False

    def __getstate__(self):
        """Return picklable state: the weakref to the parent is replaced by
        the parent object itself (weakref objects cannot be pickled)."""
        state = self.__dict__.copy()
        state['_parent'] = self.parent
        return state

    def __setstate__(self, state):
        """Restore state and re-wrap the parent in a weakref via the
        ``parent`` setter."""
        self.__dict__.update(state)
        self.parent = state['_parent']

    @property
    def parent(self):
        """ Our parent assembly (None if unset or already collected). """
        return None if self._parent is None else self._parent()

    @parent.setter
    def parent(self, parent):
        # Hold only a weak reference to the parent.
        self._parent = None if parent is None else weakref.ref(parent)

    def check_config(self, strict=False):
        """No-op; present to satisfy the component interface."""
        pass

    def cpath_updated(self):
        """No-op; present to satisfy the component interface."""
        pass

    def is_differentiable(self):
        """Return True if analytical derivatives can be
        computed for this Component.
        """
        return True

    def get_pathname(self, rel_to_scope=None):
        """ Return full pathname to this object, relative to scope
        *rel_to_scope*. If *rel_to_scope* is *None*, return the full pathname.
        """
        return '.'.join((self.parent.get_pathname(rel_to_scope), self.name))

    def list_connections(self, is_hidden=False, show_expressions=False):
        """list all of the inputs and output connections of this PseudoComponent.
        If is_hidden is True, list the connections that a user would see
        if this PseudoComponent is hidden. If show_expressions is True (and
        only if is_hidden is also True) then list the connection expression
        that resulted in the creation of this PseudoComponent.
        """
        if is_hidden:
            if self._outdests:
                if show_expressions:
                    return [(self._orig_src, self._orig_dest)]
                else:
                    return [(src, self._outdests[0])
                            for src in self._inmap.keys() if src]
            else:
                return []
        else:
            conns = [(src, '.'.join((self.name, dest)))
                     for src, dest in self._inmap.items()]
            if self._outdests:
                conns.extend([('.'.join((self.name, 'out0')), dest)
                              for dest in self._outdests])
        return conns

    def list_inputs(self, connected=True):
        """Return a copy of the input names (``connected`` is ignored)."""
        return self._inputs[:]

    def list_outputs(self, connected=True):
        """Return the single output name (``connected`` is ignored)."""
        return ['out0']

    def config_changed(self, update_parent=True):
        """No-op; present to satisfy the component interface."""
        pass

    def list_comp_connections(self):
        """Return a list of connections between our pseudocomp and
        parent components of our sources/destinations.
        """
        conns = [(src.split('.', 1)[0], self.name)
                 for src, dest in self._inmap.items()]
        if self._outdests:
            conns.extend([(self.name, dest.split('.', 1)[0])
                          for dest in self._outdests])
        return conns

    def contains(self, name):
        """Return True if *name* is one of our inputs or 'out0'."""
        return name == 'out0' or name in self._inputs

    def activate(self, scope, driver=None):
        """Add this pseudocomp to *scope* and its depgraph, then wire up its
        connections (and driver input, if *driver* is given)."""
        scope.add(self.name, self)
        scope._depgraph.add_component(self.name, self)
        getattr(scope, self.name).make_connections(scope, driver)

    def make_connections(self, scope, driver=None):
        """Connect all of the inputs and outputs of this comp to
        the appropriate nodes in the dependency graph.
        """
        for src, dest in self.list_connections():
            #scope.connect(src, dest)
            scope._depgraph.connect(scope, src, dest)

        if driver is not None:
            scope._depgraph.add_driver_input(driver.name, self.name+'.out0')

    def run(self, case_uuid=''):
        """Evaluate the source expression into 'out0' (negated if
        ``_negate`` is set)."""
        if self._negate:
            setattr(self, 'out0', -self._srcexpr.evaluate())
        else:
            setattr(self, 'out0', self._srcexpr.evaluate())

    def evaluate(self):
        """Same as run(): evaluate the source expression into 'out0'."""
        if self._negate:
            setattr(self, 'out0', -self._srcexpr.evaluate())
        else:
            setattr(self, 'out0', self._srcexpr.evaluate())

    def get(self, name):
        """Return attribute *name* of this object."""
        return getattr(self, name)

    def set(self, path, value):
        """Set attribute *path* of this object to *value*."""
        setattr(self, path, value)

    def get_metadata(self, traitpath, metaname=None):
        """Return the metadata dict for *traitpath*, or the single entry
        *metaname* from it.  A dotted *traitpath* is delegated to the child
        object named by its first segment when *metaname* is given."""
        if metaname is None:
            return self._meta[traitpath]
        childname, _, restofpath = traitpath.partition('.')
        if restofpath:
            return getattr(self, childname).get_metadata(restofpath, metaname)
        else:
            return self._meta[traitpath].get(metaname)

    def set_itername(self, itername):
        """Store the iteration name."""
        self._itername = itername

    def linearize(self, first=False, second=False):
        """Component wrapper for the ProvideJ hook.

        Returns the Jacobian when *first* is True; raises RuntimeError when
        *second* is requested; otherwise returns None.
        """
        if first:
            return self.provideJ()
        if second:
            msg = "2nd derivatives not supported in pseudocomponent %s"
            raise RuntimeError(msg % self.name)

    def provideJ(self):
        """Calculate analytical first derivatives.

        Returns an array of shape ``(n_out, n_in)`` (flattened sizes),
        filled column-block by column-block from
        ``self._srcexpr.evaluate_gradient()``.
        """
        # Jacobian shape is computed once and cached until setup_init().
        if self.Jsize is None:
            n_in = 0
            n_out = 0
            for varname in self.list_inputs():
                val = self.get(varname)
                width = flattened_size(varname, val, self)
                n_in += width
            for varname in self.list_outputs():
                val = self.get(varname)
                width = flattened_size(varname, val, self)
                n_out += width
            self.Jsize = (n_out, n_in)

        J = zeros(self.Jsize)
        grad = self._srcexpr.evaluate_gradient()

        i = 0
        for varname in self._inputs:
            val = self.get(varname)
            width = flattened_size(varname, val, self)
            J[:, i:i+width] = grad[varname]
            i += width

        if self._negate:
            return -J
        else:
            return J

    def ensure_init(self):
        """Make sure our inputs and outputs have been
        initialized.
        """
        # set the current value of the connected variable
        # into our input
        for ref, in_name in self._inmap.items():
            setattr(self, in_name, ExprEvaluator(ref).evaluate(self.parent))
            if has_interface(getattr(self, in_name), IContainer):
                getattr(self, in_name).name = in_name

        # set the initial value of the output
        outval = self._srcexpr.evaluate()
        setattr(self, 'out0', outval)

    def list_deriv_vars(self):
        """Return (input names, output names) used for derivatives."""
        return tuple(self._inputs), ('out0',)

    def get_req_cpus(self):
        """Return the (min, max) CPU request; always (1, 1)."""
        return (1, 1)

    def setup_init(self):
        """Clear cached Jacobian sizing."""
        self.Jsize = None
        self._provideJ_bounds = None

    def init_var_sizes(self):
        """Initialize input/output values so their sizes are known."""
        self.ensure_init()

    def setup_depgraph(self, dgraph):
        pass

    def setup_systems(self):
        return ()

    def setup_communicators(self, comm, scope=None):
        self.mpi.comm = comm

    def setup_variables(self):
        pass

    def setup_sizes(self):
        pass

    def setup_vectors(self, arrays=None):
        pass

    def post_setup(self):
        pass

    def get_flattened_value(self, path):
        """Return the named value, which may include
        an array index, as a flattened array of floats.  If
        the value is not flattenable into an array of floats,
        raise a TypeError.
        """
        val, idx = get_val_and_index(self, path)
        return flattened_value(path, val)

    def set_flattened_value(self, path, value):
        """Set the variable at *path* (optionally indexed) from a flattened
        *value*.

        Raises NotImplementedError for variable trees and TypeError for
        unsupported value types.
        """
        val,rop = deep_getattr(self, path.split('[',1)[0])
        idx = get_index(path)
        # NOTE(review): the next test is a plain `if`, not `elif`, so the
        # int_types branch is a no-op and the chain below still runs --
        # confirm whether int_types overlaps complex_or_real_types.
        if isinstance(val, int_types):
            pass # fall through to exception
        if isinstance(val, complex_or_real_types):
            if idx is None:
                setattr(self, path, value[0])
                return
            # else, fall through to error
        elif isinstance(val, ndarray):
            if idx is None:
                setattr(self, path, value)
            else:
                val[idx] = value
            return
        elif IVariableTree.providedBy(val):
            raise NotImplementedError("no support for setting flattened values into vartrees")

        raise TypeError("%s: Failed to set flattened value to variable %s" % (self.name, path))

    def get_req_default(self, self_reqired=None):
        # NOTE: the parameter name is misspelled ('reqired'); renaming it
        # would break keyword callers, so it is left as-is.
        return []

    def _input_updated(self, name, fullpath=None):
        pass

    def get_full_nodeset(self):
        """Return the full set of nodes in the depgraph
        belonging to this component.
        """
        return set((self.name,))

    def applyJ(self, system, variables):
        """ Wrapper for component derivative specification methods.
        Forward Mode.
        """
        # The bare name resolves to the module-level applyJ, not this method.
        applyJ(system, variables)

    def applyJT(self, system, variables):
        """ Wrapper for component derivative specification methods.
        Adjoint Mode.
        """
        # The bare name resolves to the module-level applyJT, not this method.
        applyJT(system, variables)

    def raise_exception(self, msg, exception_class=Exception):
        """Raise an exception."""
        # Walk up the parent chain until some object provides an iteration
        # name (get_itername); include it in the message if found.
        coords = ''
        obj = self
        while obj is not None:
            try:
                coords = obj.get_itername()
            except AttributeError:
                try:
                    obj = obj.parent
                except AttributeError:
                    break
            else:
                break
        if coords:
            full_msg = '%s (%s): %s' % (self.get_pathname(), coords, msg)
        else:
            full_msg = '%s: %s' % (self.get_pathname(), msg)
        raise exception_class(full_msg)
def __init__(self, parent, srcexpr, destexpr=None, translate=True, pseudo_type=None):
    """Build a pseudo-component from *srcexpr* and optional *destexpr*.

    parent: object
        Owning scope, stored directly in ``self._parent``.

    srcexpr: expression object
        Source expression; each reference from ``srcexpr.refs()`` becomes an
        input ``inN`` (initialized to None).

    destexpr: expression object (optional)
        Destination expression; may reference at most one variable, which is
        mapped to ``out0``.  Defaults to ``DummyExpr()``.

    translate: bool
        If True, rewrite the source text to use the ``inN`` names.

    pseudo_type: string (optional)
        Kind of pseudocomp (see comment below).

    Raises RuntimeError if *destexpr* references more than one variable, and
    TypeError if source and destination units are incompatible.
    """
    if destexpr is None:
        destexpr = DummyExpr()
    self.name = _get_new_name()
    self._inmap = {}  # mapping of component vars to our inputs
    self._meta = {}
    self._valid = False
    self._parent = parent
    self._inputs = []
    self.force_fd = False
    self._provideJ_bounds = None
    # a string indicating the type of pseudocomp this is, e.g., 'units',
    # 'constraint', 'objective', or 'multi_var_expr'
    self._pseudo_type = pseudo_type
    self._orig_src = srcexpr.text
    self._orig_dest = destexpr.text
    self.Jsize = None

    # varmap: reference text -> local name ('inN' or 'out0')
    # rvarmap: base variable name -> set of reference texts using it
    varmap = {}
    rvarmap = {}
    for i, ref in enumerate(srcexpr.refs()):
        in_name = 'in%d' % i
        self._inputs.append(in_name)
        self._inmap[ref] = in_name
        varmap[ref] = in_name
        rvarmap.setdefault(_get_varname(ref), set()).add(ref)
        setattr(self, in_name, None)

    refs = list(destexpr.refs())
    if refs:
        if len(refs) == 1:
            setattr(self, 'out0', None)
        else:
            raise RuntimeError(
                "output of PseudoComponent must reference only one variable"
            )
        varmap[refs[0]] = 'out0'
        rvarmap.setdefault(_get_varname(refs[0]), set()).add(refs[0])

    for name, meta in srcexpr.get_metadata():
        for rname in rvarmap[name]:
            self._meta[varmap[rname]] = meta

    for name, meta in destexpr.get_metadata():
        for rname in rvarmap[name]:
            self._meta[varmap[rname]] = meta

    if translate:
        xformed_src = transform_expression(srcexpr.text, self._inmap)
    else:
        xformed_src = srcexpr.text

    # NOTE(review): raises KeyError if no metadata was recorded for 'out0'
    # (e.g. destexpr with no refs) -- confirm DummyExpr covers this.
    out_units = self._meta['out0'].get('units')
    pq = None
    if out_units is not None:
        # evaluate the src expression using UnitsOnlyPQ objects

        tmpdict = {}

        # First, replace values with UnitsOnlyPQ objects
        for inp in self._inputs:
            units = self._meta[inp].get('units')
            if units:
                tmpdict[inp] = UnitsOnlyPQ(0., units)
            else:
                tmpdict[inp] = 0.

        # Evaluating on unit-only placeholders yields the source unit.
        pq = eval(xformed_src, _expr_dict, tmpdict)
        self._srcunits = pq.unit

        # Rewrite the source AST so its result is converted to out_units.
        unitnode = ast.parse(xformed_src)
        try:
            unitxform = unit_xform(unitnode, self._srcunits, out_units)
        except Exception as err:
            raise TypeError("Incompatible units for '%s' and '%s': %s" % (srcexpr.text,
                                                                          destexpr.text, err))
        unit_src = print_node(unitxform)
        xformed_src = unit_src
    else:
        self._srcunits = None

    # Scoped to this object so names like 'in0' resolve on the pseudocomp.
    self._srcexpr = ConnectedExprEvaluator(xformed_src, scope=self)

    # this is just the equation string (for debugging)
    if self._orig_dest:
        self._outdests = [self._orig_dest]
        if pq is None:
            sunit = dunit = ''
        else:
            sunit = "'%s'" % pq.get_unit_name()
            dunit = "'%s'" % out_units
        self._orig_expr = "%s %s -> %s %s" % (self._orig_src, sunit,
                                              self._orig_dest, dunit)
    else:
        self._outdests = []
        self._orig_expr = self._orig_src

    # Disabled code retained from an earlier revision (not executed):
    #if destexpr and destexpr.text:
        #out = destexpr.text
    #else:
        #out = 'out0'
    #if translate:
        #src = transform_expression(self._srcexpr.text,
                                   #_invert_dict(self._inmap))
    #else:
        #src = self._srcexpr.text

    #self._expr_conn = (src, out)  # the actual expression connection

    self.missing_deriv_policy = 'error'
class PseudoComponent(object):
    """A 'fake' component that is constructed from an ExprEvaluator.
    This fake component can be added to a dependency graph and executed
    along with 'real' components.

    Each variable referenced by the source expression becomes an input
    named ``in0``, ``in1``, ... and the single output is ``out0``.
    """

    implements(IComponent)

    def __init__(self, parent, srcexpr, destexpr=None, translate=True, pseudo_type=None):
        """Build a pseudo-component from *srcexpr* and optional *destexpr*.

        parent: object
            Owning scope, stored directly in ``self._parent``.

        srcexpr: expression object
            Source expression; each reference from ``srcexpr.refs()``
            becomes an input ``inN`` (initialized to None).

        destexpr: expression object (optional)
            Destination expression; may reference at most one variable,
            which is mapped to ``out0``.  Defaults to ``DummyExpr()``.

        translate: bool
            If True, rewrite the source text to use the ``inN`` names.

        pseudo_type: string (optional)
            Kind of pseudocomp (see comment below).

        Raises RuntimeError if *destexpr* references more than one variable,
        and TypeError if source and destination units are incompatible.
        """
        if destexpr is None:
            destexpr = DummyExpr()
        self.name = _get_new_name()
        self._inmap = {}  # mapping of component vars to our inputs
        self._meta = {}
        self._valid = False
        self._parent = parent
        self._inputs = []
        self.force_fd = False
        self._provideJ_bounds = None
        # a string indicating the type of pseudocomp this is, e.g., 'units',
        # 'constraint', 'objective', or 'multi_var_expr'
        self._pseudo_type = pseudo_type
        self._orig_src = srcexpr.text
        self._orig_dest = destexpr.text
        self.Jsize = None

        # varmap: reference text -> local name ('inN' or 'out0')
        # rvarmap: base variable name -> set of reference texts using it
        varmap = {}
        rvarmap = {}
        for i, ref in enumerate(srcexpr.refs()):
            in_name = 'in%d' % i
            self._inputs.append(in_name)
            self._inmap[ref] = in_name
            varmap[ref] = in_name
            rvarmap.setdefault(_get_varname(ref), set()).add(ref)
            setattr(self, in_name, None)

        refs = list(destexpr.refs())
        if refs:
            if len(refs) == 1:
                setattr(self, 'out0', None)
            else:
                raise RuntimeError(
                    "output of PseudoComponent must reference only one variable"
                )
            varmap[refs[0]] = 'out0'
            rvarmap.setdefault(_get_varname(refs[0]), set()).add(refs[0])

        for name, meta in srcexpr.get_metadata():
            for rname in rvarmap[name]:
                self._meta[varmap[rname]] = meta

        for name, meta in destexpr.get_metadata():
            for rname in rvarmap[name]:
                self._meta[varmap[rname]] = meta

        if translate:
            xformed_src = transform_expression(srcexpr.text, self._inmap)
        else:
            xformed_src = srcexpr.text

        # NOTE(review): raises KeyError if no metadata was recorded for
        # 'out0' (e.g. destexpr with no refs) -- confirm.
        out_units = self._meta['out0'].get('units')
        pq = None
        if out_units is not None:
            # evaluate the src expression using UnitsOnlyPQ objects

            tmpdict = {}

            # First, replace values with UnitsOnlyPQ objects
            for inp in self._inputs:
                units = self._meta[inp].get('units')
                if units:
                    tmpdict[inp] = UnitsOnlyPQ(0., units)
                else:
                    tmpdict[inp] = 0.

            # Evaluating on unit-only placeholders yields the source unit.
            pq = eval(xformed_src, _expr_dict, tmpdict)
            self._srcunits = pq.unit

            # Rewrite the source AST so its result is converted to out_units.
            unitnode = ast.parse(xformed_src)
            try:
                unitxform = unit_xform(unitnode, self._srcunits, out_units)
            except Exception as err:
                raise TypeError("Incompatible units for '%s' and '%s': %s" % (srcexpr.text,
                                                                              destexpr.text, err))
            unit_src = print_node(unitxform)
            xformed_src = unit_src
        else:
            self._srcunits = None

        # Scoped to this object so names like 'in0' resolve on the pseudocomp.
        self._srcexpr = ConnectedExprEvaluator(xformed_src, scope=self)

        # this is just the equation string (for debugging)
        if self._orig_dest:
            self._outdests = [self._orig_dest]
            if pq is None:
                sunit = dunit = ''
            else:
                sunit = "'%s'" % pq.get_unit_name()
                dunit = "'%s'" % out_units
            self._orig_expr = "%s %s -> %s %s" % (self._orig_src, sunit,
                                                  self._orig_dest, dunit)
        else:
            self._outdests = []
            self._orig_expr = self._orig_src

        # Disabled code retained from an earlier revision (not executed):
        #if destexpr and destexpr.text:
            #out = destexpr.text
        #else:
            #out = 'out0'
        #if translate:
            #src = transform_expression(self._srcexpr.text,
                                       #_invert_dict(self._inmap))
        #else:
            #src = self._srcexpr.text

        #self._expr_conn = (src, out)  # the actual expression connection

        self.missing_deriv_policy = 'error'

    def check_configuration(self):
        """No-op; present to satisfy the component interface."""
        pass

    def cpath_updated(self):
        """No-op; present to satisfy the component interface."""
        pass

    def get_pathname(self, rel_to_scope=None):
        """ Return full pathname to this object, relative to scope
        *rel_to_scope*. If *rel_to_scope* is *None*, return the full pathname.
        """
        return '.'.join([self._parent.get_pathname(rel_to_scope), self.name])

    def list_connections(self, is_hidden=False, show_expressions=False):
        """list all of the inputs and output connections of this PseudoComponent.
        If is_hidden is True, list the connections that a user would see
        if this PseudoComponent is hidden. If show_expressions is True (and
        only if is_hidden is also True) then list the connection expression
        that resulted in the creation of this PseudoComponent.
        """
        if is_hidden:
            if self._outdests:
                if show_expressions:
                    return [(self._orig_src, self._orig_dest)]
                else:
                    return [(src, self._outdests[0])
                            for src in self._inmap.keys() if src]
            else:
                return []
        else:
            conns = [(src, '.'.join([self.name, dest]))
                     for src, dest in self._inmap.items()]
            if self._outdests:
                conns.extend([('.'.join([self.name, 'out0']), dest)
                              for dest in self._outdests])
        return conns

    def list_inputs(self):
        """Return a copy of the input names."""
        return self._inputs[:]

    def list_outputs(self):
        """Return the single output name."""
        return ['out0']

    def list_comp_connections(self):
        """Return a list of connections between our pseudocomp and
        parent components of our sources/destinations.
        """
        conns = [(src.split('.', 1)[0], self.name)
                 for src, dest in self._inmap.items()]
        if self._outdests:
            conns.extend([(self.name, dest.split('.', 1)[0])
                          for dest in self._outdests])
        return conns

    def contains(self, name):
        """Return True if *name* is one of our inputs or 'out0'."""
        return name == 'out0' or name in self._inputs

    def make_connections(self, scope):
        """Connect all of the inputs and outputs of this comp to
        the appropriate nodes in the dependency graph.
        """
        for src, dest in self.list_connections():
            scope.connect(src, dest)

    def remove_connections(self, scope):
        """Disconnect all of the inputs and outputs of this comp
        from other nodes in the dependency graph.
        """
        for src, dest in self.list_connections():
            scope.disconnect(src, dest)

    def invalidate_deps(self, varnames=None, force=False):
        """Mark this pseudocomp invalid; arguments are ignored."""
        self._valid = False
        return None

    def get_invalidation_type(self):
        return 'full'

    def connect(self, src, dest):
        """Any new connection invalidates this pseudocomp."""
        self._valid = False

    def run(self, ffd_order=0, case_id=''):
        """Pull inputs from the parent, evaluate the source expression into
        'out0', mark valid, and notify the parent."""
        self.update_inputs()
        src = self._srcexpr.evaluate()
        setattr(self, 'out0', src)
        self._valid = True
        self._parent.child_run_finished(self.name)

    def update_inputs(self, inputs=None):
        """Ask the parent to update our inputs (*inputs* is ignored)."""
        self._parent.update_inputs(self.name)

    def update_outputs(self, names):
        """Recompute the output by running (*names* is ignored)."""
        self.run()

    def get(self, name, index=None):
        """Return attribute *name*; indexing is not supported."""
        if index is not None:
            raise RuntimeError("index not supported in PseudoComponent.get")
        return getattr(self, name)

    def set(self, path, value, index=None, src=None, force=False):
        """Set attribute *path* and invalidate; indexing is not supported."""
        if index is not None:
            raise ValueError("index not supported in PseudoComponent.set")
        self.invalidate_deps()
        setattr(self, path, value)

    def get_metadata(self, traitpath, metaname=None):
        """Return the metadata dict for *traitpath*, or its *metaname*
        entry."""
        if metaname is None:
            return self._meta[traitpath]
        return self._meta[traitpath].get(metaname)

    def is_valid(self):
        return self._valid

    def set_itername(self, itername):
        """Store the iteration name."""
        self._itername = itername

    def calc_derivatives(self, first=False, second=False, savebase=False,
                         required_inputs=None, required_outputs=None):
        """Return the Jacobian when *first* is True; raise RuntimeError for
        *second*; otherwise return None.  Other arguments are ignored."""
        if first:
            return self.provideJ()
        if second:
            msg = "2nd derivatives not supported in pseudocomponent %s"
            raise RuntimeError(msg % self.name)

    def provideJ(self):
        """Calculate analytical first derivatives.

        Returns an array of shape ``(n_out, n_in)`` (flattened sizes),
        filled column-block by column-block from
        ``self._srcexpr.evaluate_gradient()``.
        """
        # Jacobian shape is computed once and then cached in self.Jsize.
        if self.Jsize is None:
            n_in = 0
            n_out = 0
            for varname in self.list_inputs():
                val = self.get(varname)
                width = flattened_size(varname, val, self)
                n_in += width
            for varname in self.list_outputs():
                val = self.get(varname)
                width = flattened_size(varname, val, self)
                n_out += width
            self.Jsize = (n_out, n_in)

        J = zeros(self.Jsize)
        grad = self._srcexpr.evaluate_gradient()

        i = 0
        for varname in self._inputs:
            val = self.get(varname)
            width = flattened_size(varname, val, self)
            J[:, i:i + width] = grad[varname]
            i += width

        return J

    def list_deriv_vars(self):
        """Return (input names, output names) used for derivatives."""
        return tuple(self._inputs), ('out0', )
def __init__(self, parent, srcexpr, destexpr=None, translate=True,
             pseudo_type=None):
    """Build a pseudo-component from a source (and optional destination)
    expression.

    Every variable referenced by *srcexpr* becomes an input named
    ``in<i>``. The single variable referenced by *destexpr*, if any, is
    mapped to the output ``out0``. When the output has ``units`` metadata,
    the translated source expression is wrapped in a unit conversion to
    those units.

    Args:
        parent: object stored as ``self._parent``.
        srcexpr: source expression; its ``text``, ``refs()`` and
            ``get_metadata()`` are used.
        destexpr: optional destination expression. A ``DummyExpr`` is
            substituted when it is None.
        translate: if True, rewrite the source text in terms of the
            ``in<i>`` input names.
        pseudo_type: string tag for the kind of pseudo-component, e.g.
            'units', 'constraint', 'objective' or 'multi_var_expr'.

    Raises:
        RuntimeError: if *destexpr* references more than one variable.
        TypeError: if the source units cannot be converted to the output
            units.
    """
    if destexpr is None:
        destexpr = DummyExpr()
    self.name = _get_new_name()
    self._inmap = {}  # mapping of component vars to our inputs
    self._meta = {}
    self._valid = False
    self._parent = parent
    self._inputs = []
    self.force_fd = False
    self._provideJ_bounds = None
    # a string indicating the type of pseudocomp this is, e.g.,
    # 'units', 'constraint', 'objective', or 'multi_var_expr'
    self._pseudo_type = pseudo_type
    self._orig_src = srcexpr.text
    self._orig_dest = destexpr.text
    self.Jsize = None

    if destexpr.text:
        self._outdests = [destexpr.text]
    else:
        self._outdests = []

    # varmap: referenced var -> our in/out name
    # rvarmap: base var name -> set of references to it
    varmap = {}
    rvarmap = {}
    for i, ref in enumerate(srcexpr.refs()):
        in_name = 'in%d' % i
        self._inputs.append(in_name)
        self._inmap[ref] = in_name
        varmap[ref] = in_name
        rvarmap.setdefault(_get_varname(ref), set()).add(ref)
        setattr(self, in_name, None)

    refs = list(destexpr.refs())
    if refs:
        if len(refs) == 1:
            setattr(self, 'out0', None)
        else:
            raise RuntimeError("output of PseudoComponent must reference only one variable")
        varmap[refs[0]] = 'out0'
        rvarmap.setdefault(_get_varname(refs[0]), set()).add(refs[0])

    for name, meta in srcexpr.get_metadata():
        for rname in rvarmap[name]:
            self._meta[varmap[rname]] = meta

    for name, meta in destexpr.get_metadata():
        for rname in rvarmap[name]:
            self._meta[varmap[rname]] = meta

    if translate:
        xformed_src = transform_expression(srcexpr.text, self._inmap)
    else:
        xformed_src = srcexpr.text

    # 'out0' only receives metadata when destexpr referenced a variable.
    # Without a destination (e.g. destexpr=None) there are no output units,
    # so don't assume the key exists.
    out_units = self._meta.get('out0', {}).get('units')
    if out_units is not None:
        # evaluate the src expression using UnitsOnlyPQ objects
        tmpdict = {}

        # First, replace values with UnitsOnlyPQ objects
        for inp in self._inputs:
            units = self._meta[inp].get('units')
            if units:
                tmpdict[inp] = UnitsOnlyPQ(0., units)
            else:
                tmpdict[inp] = 0.

        # NOTE(review): eval of expression text; the text originates from
        # model-defined expressions, not external input.
        pq = eval(xformed_src, _expr_dict, tmpdict)
        self._srcunits = pq.unit

        unitnode = ast.parse(xformed_src)
        try:
            unitxform = unit_xform(unitnode, self._srcunits, out_units)
        except Exception as err:
            raise TypeError("Incompatible units for '%s' and '%s': %s"
                            % (srcexpr.text, destexpr.text, err))
        unit_src = print_node(unitxform)
        xformed_src = unit_src
    else:
        self._srcunits = None

    self._srcexpr = ConnectedExprEvaluator(xformed_src, scope=self)

    # this is just the equation string (for debugging)
    if destexpr and destexpr.text:
        out = destexpr.text
    else:
        out = 'out0'
    if translate:
        src = transform_expression(self._srcexpr.text,
                                   _invert_dict(self._inmap))
    else:
        src = self._srcexpr.text

    self._expr_conn = (src, out)  # the actual expression connection
    self.missing_deriv_policy = 'error'
class PseudoComponent(object):
    """A 'fake' component that is constructed from an ExprEvaluator.
    This fake component can be added to a dependency graph and executed
    along with 'real' components.
    """

    implements(IComponent)

    def __init__(self, parent, srcexpr, destexpr=None, translate=True,
                 pseudo_type=None):
        """Build a pseudo-component from a source (and optional
        destination) expression.

        Every variable referenced by *srcexpr* becomes an input named
        ``in<i>``. The single variable referenced by *destexpr*, if any, is
        mapped to the output ``out0``. When the output has ``units``
        metadata, the translated source expression is wrapped in a unit
        conversion to those units.

        Args:
            parent: owning scope; later used for ``get_pathname``,
                ``update_inputs`` and ``child_run_finished``.
            srcexpr: source expression; its ``text``, ``refs()`` and
                ``get_metadata()`` are used.
            destexpr: optional destination expression. A ``DummyExpr`` is
                substituted when it is None.
            translate: if True, rewrite the source text in terms of the
                ``in<i>`` input names.
            pseudo_type: string tag for the kind of pseudo-component, e.g.
                'units', 'constraint', 'objective' or 'multi_var_expr'.

        Raises:
            RuntimeError: if *destexpr* references more than one variable.
            TypeError: if the source units cannot be converted to the
                output units.
        """
        if destexpr is None:
            destexpr = DummyExpr()
        self.name = _get_new_name()
        self._inmap = {}  # mapping of component vars to our inputs
        self._meta = {}
        self._valid = False
        self._parent = parent
        self._inputs = []
        self.force_fd = False
        self._provideJ_bounds = None
        # a string indicating the type of pseudocomp this is, e.g.,
        # 'units', 'constraint', 'objective', or 'multi_var_expr'
        self._pseudo_type = pseudo_type
        self._orig_src = srcexpr.text
        self._orig_dest = destexpr.text
        self.Jsize = None

        if destexpr.text:
            self._outdests = [destexpr.text]
        else:
            self._outdests = []

        # varmap: referenced var -> our in/out name
        # rvarmap: base var name -> set of references to it
        varmap = {}
        rvarmap = {}
        for i, ref in enumerate(srcexpr.refs()):
            in_name = 'in%d' % i
            self._inputs.append(in_name)
            self._inmap[ref] = in_name
            varmap[ref] = in_name
            rvarmap.setdefault(_get_varname(ref), set()).add(ref)
            setattr(self, in_name, None)

        refs = list(destexpr.refs())
        if refs:
            if len(refs) == 1:
                setattr(self, 'out0', None)
            else:
                raise RuntimeError("output of PseudoComponent must reference only one variable")
            varmap[refs[0]] = 'out0'
            rvarmap.setdefault(_get_varname(refs[0]), set()).add(refs[0])

        for name, meta in srcexpr.get_metadata():
            for rname in rvarmap[name]:
                self._meta[varmap[rname]] = meta

        for name, meta in destexpr.get_metadata():
            for rname in rvarmap[name]:
                self._meta[varmap[rname]] = meta

        if translate:
            xformed_src = transform_expression(srcexpr.text, self._inmap)
        else:
            xformed_src = srcexpr.text

        # 'out0' only receives metadata when destexpr referenced a variable.
        # Without a destination (e.g. destexpr=None) there are no output
        # units, so don't assume the key exists.
        out_units = self._meta.get('out0', {}).get('units')
        if out_units is not None:
            # evaluate the src expression using UnitsOnlyPQ objects
            tmpdict = {}

            # First, replace values with UnitsOnlyPQ objects
            for inp in self._inputs:
                units = self._meta[inp].get('units')
                if units:
                    tmpdict[inp] = UnitsOnlyPQ(0., units)
                else:
                    tmpdict[inp] = 0.

            pq = eval(xformed_src, _expr_dict, tmpdict)
            self._srcunits = pq.unit

            unitnode = ast.parse(xformed_src)
            try:
                unitxform = unit_xform(unitnode, self._srcunits, out_units)
            except Exception as err:
                raise TypeError("Incompatible units for '%s' and '%s': %s"
                                % (srcexpr.text, destexpr.text, err))
            unit_src = print_node(unitxform)
            xformed_src = unit_src
        else:
            self._srcunits = None

        self._srcexpr = ConnectedExprEvaluator(xformed_src, scope=self)

        # this is just the equation string (for debugging)
        if destexpr and destexpr.text:
            out = destexpr.text
        else:
            out = 'out0'
        if translate:
            src = transform_expression(self._srcexpr.text,
                                       _invert_dict(self._inmap))
        else:
            src = self._srcexpr.text

        self._expr_conn = (src, out)  # the actual expression connection
        self.missing_deriv_policy = 'error'

    def check_configuration(self):
        pass

    def cpath_updated(self):
        pass

    def get_pathname(self, rel_to_scope=None):
        """ Return full pathname to this object, relative to scope
        *rel_to_scope*. If *rel_to_scope* is *None*, return the full pathname.
        """
        return '.'.join([self._parent.get_pathname(rel_to_scope), self.name])

    def list_connections(self, is_hidden=False, show_expressions=False):
        """list all of the inputs and output connections of this
        PseudoComponent. If is_hidden is True, list the connections
        that a user would see if this PseudoComponent is hidden. If
        show_expressions is True (and only if is_hidden is also True)
        then list the connection expression that resulted in the
        creation of this PseudoComponent.
        """
        if is_hidden:
            if not self._outdests:
                return []
            if show_expressions:
                return [self._expr_conn]
            return [(src, self._outdests[0]) for src in self._inmap if src]

        conns = [(src, '.'.join([self.name, dest]))
                 for src, dest in self._inmap.items()]
        if self._outdests:
            conns.extend([('.'.join([self.name, 'out0']), dest)
                          for dest in self._outdests])
        return conns

    def list_inputs(self):
        """Return a copy of the list of input names (``in0``, ``in1``, ...)."""
        return self._inputs[:]

    def list_outputs(self):
        """Return the list of output names (always ``['out0']``)."""
        return ['out0']

    def list_comp_connections(self):
        """Return a list of connections between our pseudocomp and
        parent components of our sources/destinations.
        """
        conns = [(src.split('.', 1)[0], self.name) for src in self._inmap]
        if self._outdests:
            conns.extend([(self.name, dest.split('.', 1)[0])
                          for dest in self._outdests])
        return conns

    def make_connections(self, scope):
        """Connect all of the inputs and outputs of this comp to
        the appropriate nodes in the dependency graph.
        """
        for src, dest in self.list_connections():
            scope.connect(src, dest)

    def remove_connections(self, scope):
        """Disconnect all of the inputs and outputs of this comp
        from other nodes in the dependency graph.
        """
        for src, dest in self.list_connections():
            scope.disconnect(src, dest)

    def invalidate_deps(self, varnames=None, force=False):
        """Mark this component invalid; the arguments are ignored."""
        self._valid = False
        return None

    def get_invalidation_type(self):
        return 'full'

    def connect(self, src, dest):
        """Any new connection invalidates this component."""
        self._valid = False

    def run(self, ffd_order=0, case_id=''):
        """Pull inputs from the parent, evaluate the source expression
        into ``out0``, mark valid and notify the parent."""
        self.update_inputs()
        src = self._srcexpr.evaluate()
        setattr(self, 'out0', src)
        self._valid = True
        self._parent.child_run_finished(self.name)

    def update_inputs(self, inputs=None):
        """Ask the parent to update our inputs; *inputs* is ignored."""
        self._parent.update_inputs(self.name)

    def update_outputs(self, names):
        self.run()

    def get(self, name, index=None):
        """Return attribute *name*; indexing is not supported."""
        if index is not None:
            raise RuntimeError("index not supported in PseudoComponent.get")
        return getattr(self, name)

    def set(self, path, value, index=None, src=None, force=False):
        """Set attribute *path* to *value* and invalidate this component;
        indexing is not supported."""
        if index is not None:
            raise ValueError("index not supported in PseudoComponent.set")
        self.invalidate_deps()
        setattr(self, path, value)

    def get_metadata(self, traitpath, metaname=None):
        """Return the whole metadata dict for *traitpath*, or only the
        *metaname* entry (None if absent) when given."""
        if metaname is None:
            return self._meta[traitpath]
        return self._meta[traitpath].get(metaname)

    def is_valid(self):
        return self._valid

    def set_itername(self, itername):
        self._itername = itername

    def calc_derivatives(self, first=False, second=False, savebase=False,
                         required_inputs=None, required_outputs=None):
        """Return the first-derivative Jacobian when *first* is True.

        Raises:
            RuntimeError: if second derivatives are requested.
        """
        if first:
            return self.provideJ()
        if second:
            msg = "2nd derivatives not supported in pseudocomponent %s"
            raise RuntimeError(msg % self.name)

    def provideJ(self):
        """Calculate analytical first derivatives.

        Returns an array of shape ``(n_out, n_in)`` (flattened widths of the
        outputs and inputs). Columns follow the order of ``self._inputs``
        and are filled from the source expression's gradient entries keyed
        by input name. The shape is computed once and cached in
        ``self.Jsize``.
        """
        if self.Jsize is None:
            n_in = 0
            n_out = 0
            for varname in self.list_inputs():
                val = self.get(varname)
                n_in += flattened_size(varname, val, self)
            for varname in self.list_outputs():
                val = self.get(varname)
                n_out += flattened_size(varname, val, self)
            self.Jsize = (n_out, n_in)

        J = zeros(self.Jsize)
        grad = self._srcexpr.evaluate_gradient()
        i = 0
        for varname in self._inputs:
            val = self.get(varname)
            width = flattened_size(varname, val, self)
            J[:, i:i + width] = grad[varname]
            i += width
        return J

    def list_deriv_vars(self):
        """Return (input names, output names) that have derivatives."""
        return tuple(self._inputs), ('out0',)
def __init__(self, parent, srcexpr, destexpr=None, translate=True,
             pseudo_type=None, subtype=None, exprobject=None):
    """Build a pseudo-component from a source (and optional destination)
    expression.

    Every variable referenced by ``srcexpr.ordered_refs()`` becomes an input
    named ``in<i>`` (initialized to 0.). The single variable referenced by
    *destexpr*, if any, is mapped to the output ``out0``. If any source
    metadata carries a ``noflat`` key, destination metadata is marked
    ``noflat`` as well. When the output has ``units`` metadata, the
    translated source expression is wrapped in a unit conversion.

    Args:
        parent: owning object, assigned through the ``parent`` property.
        srcexpr: source expression; its ``text``, ``ordered_refs()`` and
            ``get_metadata()`` are used.
        destexpr: optional destination expression. A ``DummyExpr`` is
            substituted when it is None.
        translate: if True, rewrite the source text in terms of the
            ``in<i>`` input names.
        pseudo_type: e.g. 'units', 'constraint', 'objective' or
            'multi_var_expr'.
        subtype: for constraints, 'equality' or 'inequality'.
        exprobject: object responsible for creating this pcomp, e.g. a
            Constraint.

    Raises:
        RuntimeError: if *destexpr* references more than one variable.
        TypeError: if the source units cannot be converted to the output
            units.
    """
    if destexpr is None:
        destexpr = DummyExpr()
    self._parent = None
    self.parent = parent
    self.name = _get_new_name(parent)
    self._inmap = {}  # mapping of component vars to our inputs
    self._meta = {}
    self._inputs = []

    # Flags and caching used by the derivatives calculation
    self.force_fd = False
    self._provideJ_bounds = None

    # a string indicating the type of pseudocomp this is, e.g.,
    # 'units', 'constraint', 'objective', or 'multi_var_expr'
    self._pseudo_type = pseudo_type
    # for constraints, 'equality' or 'inequality'
    self._subtype = subtype
    # object responsible for creation of this pcomp, e.g. a Constraint
    self._exprobj = exprobject

    self._orig_src = srcexpr.text
    self._orig_dest = destexpr.text
    self.Jsize = None
    self.mpi = MPI_info()

    # varmap: referenced var -> our in/out name
    # rvarmap: base var name -> set of references to it
    varmap = {}
    rvarmap = {}
    for i, ref in enumerate(srcexpr.ordered_refs()):
        in_name = 'in%d' % i
        self._inputs.append(in_name)
        self._inmap[ref] = in_name
        varmap[ref] = in_name
        rvarmap.setdefault(_get_varname(ref), set()).add(ref)
        setattr(self, in_name, 0.)

    refs = list(destexpr.refs())
    if refs:
        if len(refs) == 1:
            setattr(self, 'out0', None)
        else:
            raise RuntimeError("output of PseudoComponent must reference"
                               " only one variable")
        varmap[refs[0]] = 'out0'
        rvarmap.setdefault(_get_varname(refs[0]), set()).add(refs[0])

    noflat = False
    for name, meta in srcexpr.get_metadata():
        # If any input is noflat, then the output must be too.
        if 'noflat' in meta:
            noflat = True
        for rname in rvarmap[name]:
            self._meta[varmap[rname]] = meta

    for name, meta in destexpr.get_metadata():
        if noflat and 'noflat' not in meta:
            # NOTE(review): this mutates the dict returned by
            # destexpr.get_metadata(); if that dict is shared with the
            # destination variable, the flag leaks to it. Confirm intent.
            meta['noflat'] = True
        for rname in rvarmap[name]:
            self._meta[varmap[rname]] = meta

    if translate:
        xformed_src = transform_expression(srcexpr.text, self._inmap)
    else:
        xformed_src = srcexpr.text

    # 'out0' only receives metadata when destexpr referenced a variable.
    # Without a destination (e.g. destexpr=None) there are no output units,
    # so don't assume the key exists.
    out_units = self._meta.get('out0', {}).get('units')
    pq = None
    if out_units is not None:
        # evaluate the src expression using UnitsOnlyPQ objects
        tmpdict = {}

        # First, replace values with UnitsOnlyPQ objects
        for inp in self._inputs:
            units = self._meta[inp].get('units')
            if units:
                tmpdict[inp] = UnitsOnlyPQ(0., units)
            else:
                tmpdict[inp] = 0.

        pq = eval(xformed_src, _expr_dict, tmpdict)
        self._srcunits = pq.unit

        unitnode = ast.parse(xformed_src)
        try:
            unitxform = unit_xform(unitnode, self._srcunits, out_units)
        except Exception as err:
            raise TypeError("Incompatible units for '%s' and '%s': %s"
                            % (srcexpr.text, destexpr.text, err))
        unit_src = print_node(unitxform)
        xformed_src = unit_src
    else:
        self._srcunits = None

    self._srcexpr = ConnectedExprEvaluator(xformed_src, scope=self)

    # this is just the equation string (for debugging)
    if self._orig_dest:
        self._outdests = [self._orig_dest]
        if pq is None:
            sunit = dunit = ''
        else:
            sunit = "'%s'" % pq.get_unit_name()
            dunit = "'%s'" % out_units
        self._orig_expr = "%s %s -> %s %s" % (self._orig_src, sunit,
                                              self._orig_dest, dunit)
    else:
        self._outdests = []
        self._orig_expr = self._orig_src

    self.missing_deriv_policy = 'error'
    self._negate = False
class PseudoComponent(object):
    """A 'fake' component that is constructed from an ExprEvaluator.
    This fake component can be added to a dependency graph and executed
    along with 'real' components.
    """

    implements(IComponent, IPseudoComp)

    def __init__(self, parent, srcexpr, destexpr=None, translate=True,
                 pseudo_type=None, subtype=None, exprobject=None):
        """Build a pseudo-component from a source (and optional
        destination) expression.

        Every variable referenced by ``srcexpr.ordered_refs()`` becomes an
        input named ``in<i>`` (initialized to 0.). The single variable
        referenced by *destexpr*, if any, is mapped to the output ``out0``.
        When the output has ``units`` metadata, the translated source
        expression is wrapped in a unit conversion.

        Args:
            parent: owning object, held weakly via the ``parent`` property.
            srcexpr: source expression; its ``text``, ``ordered_refs()`` and
                ``get_metadata()`` are used.
            destexpr: optional destination expression. A ``DummyExpr`` is
                substituted when it is None.
            translate: if True, rewrite the source text in terms of the
                ``in<i>`` input names.
            pseudo_type: e.g. 'units', 'constraint', 'objective' or
                'multi_var_expr'.
            subtype: for constraints, 'equality' or 'inequality'.
            exprobject: object responsible for creating this pcomp, e.g. a
                Constraint.

        Raises:
            RuntimeError: if *destexpr* references more than one variable.
            TypeError: if the source units cannot be converted to the
                output units.
        """
        if destexpr is None:
            destexpr = DummyExpr()
        self._parent = None
        self.parent = parent
        self.name = _get_new_name(parent)
        self._inmap = {}  # mapping of component vars to our inputs
        self._meta = {}
        self._inputs = []
        self._initialized = False

        # Flags and caching used by the derivatives calculation
        self.force_fd = False
        self._provideJ_bounds = None

        # a string indicating the type of pseudocomp this is, e.g.,
        # 'units', 'constraint', 'objective', or 'multi_var_expr'
        self._pseudo_type = pseudo_type
        # for constraints, 'equality' or 'inequality'
        self._subtype = subtype
        # object responsible for creation of this pcomp, e.g. a Constraint
        self._exprobj = exprobject

        self._orig_src = srcexpr.text
        self._orig_dest = destexpr.text
        self.Jsize = None
        self.mpi = MPI_info()

        # varmap: referenced var -> our in/out name
        # rvarmap: base var name -> set of references to it
        varmap = {}
        rvarmap = {}
        for i, ref in enumerate(srcexpr.ordered_refs()):
            in_name = 'in%d' % i
            self._inputs.append(in_name)
            self._inmap[ref] = in_name
            varmap[ref] = in_name
            rvarmap.setdefault(_get_varname(ref), set()).add(ref)
            setattr(self, in_name, 0.)

        refs = list(destexpr.refs())
        if refs:
            if len(refs) == 1:
                setattr(self, 'out0', None)
            else:
                raise RuntimeError("output of PseudoComponent must reference"
                                   " only one variable")
            varmap[refs[0]] = 'out0'
            rvarmap.setdefault(_get_varname(refs[0]), set()).add(refs[0])

        for name, meta in srcexpr.get_metadata():
            for rname in rvarmap[name]:
                self._meta[varmap[rname]] = meta

        for name, meta in destexpr.get_metadata():
            for rname in rvarmap[name]:
                self._meta[varmap[rname]] = meta

        if translate:
            xformed_src = transform_expression(srcexpr.text, self._inmap)
        else:
            xformed_src = srcexpr.text

        # 'out0' only receives metadata when destexpr referenced a variable.
        # Without a destination (e.g. destexpr=None) there are no output
        # units, so don't assume the key exists.
        out_units = self._meta.get('out0', {}).get('units')
        pq = None
        if out_units is not None:
            # evaluate the src expression using UnitsOnlyPQ objects
            tmpdict = {}

            # First, replace values with UnitsOnlyPQ objects
            for inp in self._inputs:
                units = self._meta[inp].get('units')
                if units:
                    tmpdict[inp] = UnitsOnlyPQ(0., units)
                else:
                    tmpdict[inp] = 0.

            pq = eval(xformed_src, _expr_dict, tmpdict)
            self._srcunits = pq.unit

            unitnode = ast.parse(xformed_src)
            try:
                unitxform = unit_xform(unitnode, self._srcunits, out_units)
            except Exception as err:
                raise TypeError("Incompatible units for '%s' and '%s': %s"
                                % (srcexpr.text, destexpr.text, err))
            unit_src = print_node(unitxform)
            xformed_src = unit_src
        else:
            self._srcunits = None

        self._srcexpr = ConnectedExprEvaluator(xformed_src, scope=self)

        # this is just the equation string (for debugging)
        if self._orig_dest:
            self._outdests = [self._orig_dest]
            if pq is None:
                sunit = dunit = ''
            else:
                sunit = "'%s'" % pq.get_unit_name()
                dunit = "'%s'" % out_units
            self._orig_expr = "%s %s -> %s %s" % (self._orig_src, sunit,
                                                  self._orig_dest, dunit)
        else:
            self._outdests = []
            self._orig_expr = self._orig_src

        self.missing_deriv_policy = 'error'
        self._negate = False

    def __getstate__(self):
        """Pickle support: store a strong reference to the parent in place
        of the weakref, which cannot be pickled."""
        state = self.__dict__.copy()
        state['_parent'] = self.parent
        return state

    def __setstate__(self, state):
        """Restore state and re-wrap the parent in a weakref."""
        self.__dict__.update(state)
        self.parent = state['_parent']

    @property
    def parent(self):
        """ Our parent assembly. """
        return None if self._parent is None else self._parent()

    @parent.setter
    def parent(self, parent):
        # hold the parent weakly so this pseudocomp doesn't keep it alive
        self._parent = None if parent is None else weakref.ref(parent)

    def check_config(self, strict=False):
        pass

    def cpath_updated(self):
        pass

    def is_differentiable(self):
        """Return True if analytical derivatives can be
        computed for this Component.
        """
        return True

    def get_pathname(self, rel_to_scope=None):
        """ Return full pathname to this object, relative to scope
        *rel_to_scope*. If *rel_to_scope* is *None*, return the full pathname.
        """
        return '.'.join((self.parent.get_pathname(rel_to_scope), self.name))

    def list_connections(self, is_hidden=False, show_expressions=False):
        """list all of the inputs and output connections of this
        PseudoComponent. If is_hidden is True, list the connections
        that a user would see if this PseudoComponent is hidden. If
        show_expressions is True (and only if is_hidden is also True)
        then list the connection expression that resulted in the
        creation of this PseudoComponent.
        """
        if is_hidden:
            if not self._outdests:
                return []
            if show_expressions:
                return [(self._orig_src, self._orig_dest)]
            return [(src, self._outdests[0]) for src in self._inmap if src]

        conns = [(src, '.'.join((self.name, dest)))
                 for src, dest in self._inmap.items()]
        if self._outdests:
            conns.extend([('.'.join((self.name, 'out0')), dest)
                          for dest in self._outdests])
        return conns

    def list_inputs(self, connected=True):
        """Return a copy of the input names; *connected* is ignored."""
        return self._inputs[:]

    def list_outputs(self, connected=True):
        """Return the output names (always ``['out0']``); *connected* is
        ignored."""
        return ['out0']

    def config_changed(self, update_parent=True):
        pass

    def list_comp_connections(self):
        """Return a list of connections between our pseudocomp and
        parent components of our sources/destinations.
        """
        conns = [(src.split('.', 1)[0], self.name) for src in self._inmap]
        if self._outdests:
            conns.extend([(self.name, dest.split('.', 1)[0])
                          for dest in self._outdests])
        return conns

    def contains(self, name):
        """Return True if *name* is ``out0`` or one of our inputs."""
        return name == 'out0' or name in self._inputs

    def make_connections(self, scope, driver=None):
        """Connect all of the inputs and outputs of this comp to
        the appropriate nodes in the dependency graph.

        If *driver* is given, ``<name>.out0`` is also registered as an
        input of that driver in ``scope._depgraph``.
        """
        for src, dest in self.list_connections():
            scope.connect(src, dest)
        if driver is not None:
            scope._depgraph.add_driver_input(driver.name, self.name + '.out0')

    def run(self, case_uuid=''):
        """Evaluate the source expression into ``out0`` (see evaluate)."""
        self.evaluate()

    def evaluate(self):
        """Evaluate the source expression and store it in ``out0``,
        negated when ``self._negate`` is set."""
        if self._negate:
            setattr(self, 'out0', -self._srcexpr.evaluate())
        else:
            setattr(self, 'out0', self._srcexpr.evaluate())

    def get(self, name):
        return getattr(self, name)

    def set(self, path, value):
        setattr(self, path, value)

    def get_metadata(self, traitpath, metaname=None):
        """Return metadata for *traitpath*.

        With no *metaname*, return the whole metadata dict. For a dotted
        path, delegate the remainder to the named child's get_metadata;
        otherwise return the *metaname* entry (None if absent).
        """
        if metaname is None:
            return self._meta[traitpath]
        childname, _, restofpath = traitpath.partition('.')
        if restofpath:
            return getattr(self, childname).get_metadata(restofpath, metaname)
        return self._meta[traitpath].get(metaname)

    def set_itername(self, itername):
        self._itername = itername

    def linearize(self, first=False, second=False):
        """Component wrapper for the ProvideJ hook.

        Raises:
            RuntimeError: if second derivatives are requested.
        """
        if first:
            return self.provideJ()
        if second:
            msg = "2nd derivatives not supported in pseudocomponent %s"
            raise RuntimeError(msg % self.name)

    def provideJ(self):
        """Calculate analytical first derivatives.

        Returns an array of shape ``(n_out, n_in)`` (flattened widths of the
        outputs and inputs), negated when ``self._negate`` is set. Columns
        follow the order of ``self._inputs`` and are filled from the source
        expression's gradient entries keyed by input name. The shape is
        computed once and cached in ``self.Jsize``.
        """
        if self.Jsize is None:
            n_in = 0
            n_out = 0
            for varname in self.list_inputs():
                val = self.get(varname)
                n_in += flattened_size(varname, val, self)
            for varname in self.list_outputs():
                val = self.get(varname)
                n_out += flattened_size(varname, val, self)
            self.Jsize = (n_out, n_in)

        J = zeros(self.Jsize)
        grad = self._srcexpr.evaluate_gradient()
        i = 0
        for varname in self._inputs:
            val = self.get(varname)
            width = flattened_size(varname, val, self)
            J[:, i:i + width] = grad[varname]
            i += width
        if self._negate:
            return -J
        return J

    def ensure_init(self):
        """Make sure our inputs and outputs have been initialized.
        """
        if not self._initialized:
            # set the current value of the connected variable
            # into our input
            for ref, in_name in self._inmap.items():
                setattr(self, in_name, ExprEvaluator(ref).evaluate(self.parent))
                if has_interface(getattr(self, in_name), IContainer):
                    getattr(self, in_name).name = in_name

            # set the initial value of the output
            setattr(self, 'out0', self._srcexpr.evaluate())
            self._initialized = True

    def list_deriv_vars(self):
        """Return (input names, output names) that have derivatives."""
        return tuple(self._inputs), ('out0', )

    def get_req_cpus(self):
        return 1

    def pre_setup(self):
        self.ensure_init()

    def setup_depgraph(self, dgraph):
        pass

    def setup_systems(self):
        return ()

    def setup_communicators(self, comm, scope=None):
        self.mpi.comm = comm

    def setup_variables(self):
        pass

    def setup_sizes(self):
        pass

    def setup_vectors(self, arrays=None):
        pass

    def post_setup(self):
        pass

    def get_flattened_value(self, path):
        """Return the named value, which may include
        an array index, as a flattened array of floats.  If
        the value is not flattenable into an array of floats,
        raise a TypeError.
        """
        self.ensure_init()
        val, _ = get_val_and_index(self, path)
        return flattened_value(path, val)

    def set_flattened_value(self, path, value):
        """Set the variable named by *path* (optionally indexed) from the
        flattened *value*.

        Real/complex scalars are set from ``value[0]`` (unindexed only);
        arrays are replaced wholesale or updated at the given index.

        Raises:
            NotImplementedError: for variable trees.
            TypeError: for integers, indexed scalars and other types.
        """
        self.ensure_init()
        val, _ = deep_getattr(self, path.split('[', 1)[0])
        idx = get_index(path)
        if isinstance(val, int_types):
            pass  # fall through to exception
        elif isinstance(val, complex_or_real_types):
            if idx is None:
                setattr(self, path, value[0])
                return
            # else, fall through to error
        elif isinstance(val, ndarray):
            if idx is None:
                setattr(self, path, value)
            else:
                val[idx] = value
            return
        elif IVariableTree.providedBy(val):
            raise NotImplementedError(
                "no support for setting flattened values into vartrees")

        raise TypeError("%s: Failed to set flattened value to variable %s"
                        % (self.name, path))

    # NOTE: the parameter name 'self_reqired' is misspelled but kept as-is so
    # keyword callers keep working.
    def get_req_default(self, self_reqired=None):
        return []

    def _input_updated(self, name, fullpath=None):
        pass

    def get_full_nodeset(self):
        """Return the full set of nodes in the depgraph
        belonging to this component.
        """
        return set((self.name, ))

    def applyJ(self, system, variables):
        """ Wrapper for component derivative specification methods.
        Forward Mode.
        """
        # resolves to the module-level applyJ, not this method
        applyJ(system, variables)

    def applyJT(self, system, variables):
        """ Wrapper for component derivative specification methods.
        Adjoint Mode.
        """
        # resolves to the module-level applyJT, not this method
        applyJT(system, variables)

    def raise_exception(self, msg, exception_class=Exception):
        """Raise *exception_class* with *msg* prefixed by our pathname and,
        when found, the iteration coordinates of the nearest object (self
        or an ancestor reached through ``parent``) having get_itername().
        """
        coords = ''
        obj = self
        while obj is not None:
            try:
                coords = obj.get_itername()
            except AttributeError:
                try:
                    obj = obj.parent
                except AttributeError:
                    break
            else:
                break
        if coords:
            full_msg = '%s (%s): %s' % (self.get_pathname(), coords, msg)
        else:
            full_msg = '%s: %s' % (self.get_pathname(), msg)
        raise exception_class(full_msg)