# Example #1
class PseudoComponent(object):
    """A 'fake' component that is constructed from an ExprEvaluator.
    This fake component can be added to a dependency graph and executed
    along with 'real' components.

    Each variable referenced by the source expression becomes an input named
    ``in0``, ``in1``, ...; the value of the expression is stored in the single
    output ``out0``.
    """

    def __init__(self, parent, srcexpr, destexpr=None, translate=True,
                 pseudo_type=None):
        """Build the pseudocomponent from *srcexpr* and optional *destexpr*.

        parent: object
            Owning scope; ``get_pathname`` delegates to its ``get_pathname``.

        srcexpr: expression evaluator
            Source expression. Each entry of ``srcexpr.refs()`` becomes an
            input ``inN``; its ``text`` is compiled into ``self._srcexpr``.

        destexpr: expression evaluator or None
            Destination of ``out0``. ``None`` is replaced by ``DummyExpr()``.

        translate: bool
            If True, rewrite the source text in terms of the ``inN`` names
            (via ``transform_expression``) before compiling it.

        pseudo_type: str or None
            Tag for the kind of pseudocomp, e.g. 'units', 'constraint',
            'objective' or 'multi_var_expr'.

        Raises RuntimeError if *destexpr* does not reference exactly one
        variable, and TypeError if ``unit_xform`` rejects the combination of
        source and destination units.
        """
        if destexpr is None:
            destexpr = DummyExpr()
        self.name = _get_new_name()
        self._inmap = {}  # mapping of component vars to our inputs
        self._meta = {}  # metadata dict keyed by our local names (inN/out0)
        self._valid = False
        self._parent = parent
        self._inputs = []  # our input names, in creation order
        self.force_fd = False
        self._provideJ_bounds = None
        # a string indicating the type of pseudocomp this is, e.g. 'units',
        # 'constraint', 'objective', or 'multi_var_expr'
        self._pseudo_type = pseudo_type
        self._orig_src = srcexpr.text
        self._orig_dest = destexpr.text
        self.Jsize = None  # (n_out, n_in); computed lazily by provideJ

        varmap = {}   # variable reference -> our local name
        rvarmap = {}  # _get_varname(ref) -> set of refs sharing that name
        for i, ref in enumerate(srcexpr.refs()):
            in_name = 'in%d' % i
            # list_inputs, contains, provideJ and the units check below all
            # read self._inputs, so every input must be recorded here.
            self._inputs.append(in_name)
            self._inmap[ref] = in_name
            varmap[ref] = in_name
            rvarmap.setdefault(_get_varname(ref), set()).add(ref)
            setattr(self, in_name, None)

        refs = list(destexpr.refs())
        # An empty list used to fail later with an IndexError on refs[0];
        # report it with the same error as the too-many-refs case.
        if len(refs) != 1:
            raise RuntimeError(
                "output of PseudoComponent must reference only one variable")
        setattr(self, 'out0', None)
        varmap[refs[0]] = 'out0'
        rvarmap.setdefault(_get_varname(refs[0]), set()).add(refs[0])

        # Copy each referenced variable's metadata onto the local name(s)
        # standing in for it (source first, then destination).
        for expr in (srcexpr, destexpr):
            for name, meta in expr.get_metadata():
                for rname in rvarmap[name]:
                    self._meta[varmap[rname]] = meta

        if translate:
            xformed_src = transform_expression(srcexpr.text, self._inmap)
        else:
            xformed_src = srcexpr.text

        out_units = self._meta['out0'].get('units')
        pq = None
        if out_units is not None:
            # Evaluate the src expression with UnitsOnlyPQ stand-ins (0. for
            # unitless inputs) to discover the units it produces.
            tmpdict = {}
            for inp in self._inputs:
                units = self._meta[inp].get('units')
                if units:
                    tmpdict[inp] = UnitsOnlyPQ(0., units)
                else:
                    tmpdict[inp] = 0.

            pq = eval(xformed_src, _expr_dict, tmpdict)
            self._srcunits = pq.unit

            # Rewrite the expression AST with unit_xform (source units ->
            # out_units) and print it back to source text.
            unitnode = ast.parse(xformed_src)
            try:
                unitxform = unit_xform(unitnode, self._srcunits, out_units)
            except Exception as err:
                raise TypeError("Incompatible units for '%s' and '%s': %s" %
                                (srcexpr.text, destexpr.text, err))
            xformed_src = print_node(unitxform)
        else:
            self._srcunits = None

        self._srcexpr = ConnectedExprEvaluator(xformed_src, scope=self)

        # _orig_expr is just the equation string (for debugging)
        if self._orig_dest:
            self._outdests = [self._orig_dest]
            if pq is None:
                sunit = dunit = ''
            else:
                sunit = "'%s'" % pq.get_unit_name()
                dunit = "'%s'" % out_units
            self._orig_expr = "%s %s -> %s %s" % (self._orig_src, sunit,
                                                  self._orig_dest, dunit)
        else:
            self._outdests = []
            self._orig_expr = self._orig_src

        self.missing_deriv_policy = 'error'

    def check_configuration(self):
        """No-op hook; this class performs no configuration checks."""
        pass

    def cpath_updated(self):
        """No-op hook."""
        pass

    def get_pathname(self, rel_to_scope=None):
        """ Return full pathname to this object, relative to scope
        *rel_to_scope*. If *rel_to_scope* is *None*, return the full pathname.
        """
        return '.'.join([self._parent.get_pathname(rel_to_scope), self.name])

    def list_connections(self, is_hidden=False, show_expressions=False):
        """List all of the input and output connections of this PseudoComponent
        as (src, dest) tuples.

        If is_hidden is True, list the connections that a user would see
        if this PseudoComponent is hidden. If show_expressions is True (and
        only if is_hidden is also True) then list the connection expression
        that resulted in the creation of this PseudoComponent.
        """
        if is_hidden:
            if not self._outdests:
                return []
            if show_expressions:
                return [(self._orig_src, self._orig_dest)]
            return [(src, self._outdests[0]) for src in self._inmap if src]

        conns = [(src, '.'.join([self.name, dest]))
                 for src, dest in self._inmap.items()]
        if self._outdests:
            conns.extend([('.'.join([self.name, 'out0']), dest)
                          for dest in self._outdests])
        return conns

    def list_inputs(self):
        """Return a copy of the list of input names."""
        return self._inputs[:]

    def list_outputs(self):
        """Return the output names (always just ``out0``)."""
        return ['out0']

    def list_comp_connections(self):
        """Return a list of connections between our pseudocomp and
        parent components of our sources/destinations.
        """
        conns = [(src.split('.', 1)[0], self.name) for src in self._inmap]
        if self._outdests:
            conns.extend([(self.name, dest.split('.', 1)[0])
                          for dest in self._outdests])
        return conns

    def contains(self, name):
        """Return True if *name* is ``out0`` or one of our inputs."""
        return name == 'out0' or name in self._inputs

    def make_connections(self, scope):
        """Connect all of the inputs and outputs of this comp to
        the appropriate nodes in the dependency graph.
        """
        for src, dest in self.list_connections():
            scope.connect(src, dest)

    def remove_connections(self, scope):
        """Disconnect all of the inputs and outputs of this comp
        from other nodes in the dependency graph.
        """
        for src, dest in self.list_connections():
            scope.disconnect(src, dest)

    def invalidate_deps(self, varnames=None, force=False):
        """Mark this component invalid; the arguments are ignored.

        Always returns None.
        """
        self._valid = False
        return None

    def get_invalidation_type(self):
        """Return the invalidation type, which is always 'full'."""
        return 'full'

    def connect(self, src, dest):
        """Mark this component invalid; *src* and *dest* are not stored."""
        self._valid = False

    def run(self, ffd_order=0, case_id=''):
        """Evaluate the source expression into ``out0`` and mark this
        component valid. *ffd_order* and *case_id* are not used here.
        """
        self.out0 = self._srcexpr.evaluate()
        self._valid = True

    def update_inputs(self, inputs=None):
        """No-op hook."""
        pass

    def update_outputs(self, names):
        """No-op hook."""
        pass

    def get(self, name, index=None):
        """Return attribute *name*. Raises RuntimeError if *index* is given."""
        if index is not None:
            raise RuntimeError("index not supported in PseudoComponent.get")
        return getattr(self, name)

    def set(self, path, value, index=None, src=None, force=False):
        """Set attribute *path* to *value*; *src* and *force* are ignored.

        Raises ValueError if *index* is given.
        """
        if index is not None:
            raise ValueError("index not supported in PseudoComponent.set")
        setattr(self, path, value)

    def get_metadata(self, traitpath, metaname=None):
        """Return the metadata dict for *traitpath*, or the single entry
        *metaname* from it (None if absent).
        """
        if metaname is None:
            return self._meta[traitpath]
        return self._meta[traitpath].get(metaname)

    def is_valid(self):
        """Return True if ``run`` has completed since the last invalidation."""
        return self._valid

    def set_itername(self, itername):
        """Record the iteration name."""
        self._itername = itername

    def calc_derivatives(self, first=False, second=False, savebase=False,
                         required_inputs=None, required_outputs=None):
        """Return ``provideJ()`` when *first* is True.

        Raises RuntimeError when *second* is requested; returns None when
        neither is. The remaining arguments are ignored.
        """
        if first:
            return self.provideJ()
        if second:
            msg = "2nd derivatives not supported in pseudocomponent %s"
            raise RuntimeError(msg % self.name)

    def provideJ(self):
        """Calculate analytical first derivatives.

        Returns an (n_out, n_in) array whose columns follow the order of
        ``self._inputs``; the shape is cached in ``self.Jsize``.
        """
        if self.Jsize is None:
            n_in = sum(flattened_size(name, self.get(name), self)
                       for name in self.list_inputs())
            n_out = sum(flattened_size(name, self.get(name), self)
                        for name in self.list_outputs())
            self.Jsize = (n_out, n_in)

        J = zeros(self.Jsize)
        grad = self._srcexpr.evaluate_gradient()

        i = 0
        for varname in self._inputs:
            width = flattened_size(varname, self.get(varname), self)
            J[:, i:i + width] = grad[varname]
            i += width

        return J

    def list_deriv_vars(self):
        """Return (input names, output names) used for derivatives."""
        return tuple(self._inputs), ('out0', )
class PseudoComponent(object):
    """A 'fake' component that is constructed from an ExprEvaluator.
    This fake component can be added to a dependency graph and executed
    along with 'real' components.

    Each variable referenced by the source expression becomes an input named
    ``in0``, ``in1``, ...; the value of the expression (negated when
    ``_negate`` is set) is stored in the single output ``out0``.
    """
    implements(IComponent, IPseudoComp)

    def __init__(self, parent, srcexpr, destexpr=None,
                 translate=True, pseudo_type=None, subtype=None,
                 exprobject=None):
        """Build the pseudocomponent from *srcexpr* and optional *destexpr*.

        parent: object
            Owning scope, held through a weak reference (see ``parent``).

        srcexpr: expression evaluator
            Source expression; each entry of ``srcexpr.ordered_refs()``
            becomes an input ``inN``.

        destexpr: expression evaluator or None
            Destination of ``out0``. ``None`` is replaced by ``DummyExpr()``.

        translate: bool
            If True, rewrite the source text in terms of the ``inN`` names
            before compiling it.

        pseudo_type: str or None
            e.g. 'units', 'constraint', 'objective' or 'multi_var_expr'.

        subtype: str or None
            For constraints, 'equality' or 'inequality'.

        exprobject: object or None
            Object responsible for creating this pcomp, e.g. a Constraint.

        Raises RuntimeError if *destexpr* does not reference exactly one
        variable, and TypeError if ``unit_xform`` rejects the combination of
        source and destination units.
        """
        if destexpr is None:
            destexpr = DummyExpr()
        self._parent = None
        self.parent = parent  # the property setter stores a weakref
        self.name = _get_new_name(parent)
        self._inmap = {}  # mapping of component vars to our inputs
        self._meta = {}  # metadata dict keyed by our local names (inN/out0)
        self._inputs = []  # our input names, in creation order

        # Flags and caching used by the derivatives calculation
        self.force_fd = False
        self._provideJ_bounds = None

        # a string indicating the type of pseudocomp this is, e.g. 'units',
        # 'constraint', 'objective', or 'multi_var_expr'
        self._pseudo_type = pseudo_type
        self._subtype = subtype  # for constraints, 'equality' or 'inequality'

        # object responsible for creation of this pcomp, e.g. a Constraint
        self._exprobj = exprobject

        self._orig_src = srcexpr.text
        self._orig_dest = destexpr.text
        self.Jsize = None  # (n_out, n_in); computed lazily by provideJ
        self.mpi = MPI_info()

        varmap = {}   # variable reference -> our local name
        rvarmap = {}  # _get_varname(ref) -> set of refs sharing that name
        for i, ref in enumerate(srcexpr.ordered_refs()):
            in_name = 'in%d' % i
            # list_inputs, contains, provideJ and the units check below all
            # read self._inputs, so every input must be recorded here.
            self._inputs.append(in_name)
            self._inmap[ref] = in_name
            varmap[ref] = in_name
            rvarmap.setdefault(_get_varname(ref), set()).add(ref)
            setattr(self, in_name, 0.)

        refs = list(destexpr.refs())
        # An empty list used to fail later with an IndexError on refs[0];
        # report it with the same error as the too-many-refs case.
        if len(refs) != 1:
            raise RuntimeError("output of PseudoComponent must reference"
                               " only one variable")
        setattr(self, 'out0', None)
        varmap[refs[0]] = 'out0'
        rvarmap.setdefault(_get_varname(refs[0]), set()).add(refs[0])

        noflat = False
        for name, meta in srcexpr.get_metadata():
            # If any input is noflat, then the output must be too.
            if 'noflat' in meta:
                noflat = True
            for rname in rvarmap[name]:
                self._meta[varmap[rname]] = meta

        for name, meta in destexpr.get_metadata():
            # NOTE(review): this mutates the dict returned by get_metadata();
            # confirm whether that is a copy or the variable's live metadata.
            if noflat and 'noflat' not in meta:
                meta['noflat'] = True
            for rname in rvarmap[name]:
                self._meta[varmap[rname]] = meta

        if translate:
            xformed_src = transform_expression(srcexpr.text, self._inmap)
        else:
            xformed_src = srcexpr.text

        out_units = self._meta['out0'].get('units')
        pq = None
        if out_units is not None:
            # Evaluate the src expression with UnitsOnlyPQ stand-ins (0. for
            # unitless inputs) to discover the units it produces.
            tmpdict = {}
            for inp in self._inputs:
                units = self._meta[inp].get('units')
                if units:
                    tmpdict[inp] = UnitsOnlyPQ(0., units)
                else:
                    tmpdict[inp] = 0.

            pq = eval(xformed_src, _expr_dict, tmpdict)
            self._srcunits = pq.unit

            # Rewrite the expression AST with unit_xform (source units ->
            # out_units) and print it back to source text.
            unitnode = ast.parse(xformed_src)
            try:
                unitxform = unit_xform(unitnode, self._srcunits, out_units)
            except Exception as err:
                raise TypeError("Incompatible units for '%s' and '%s': %s"
                                % (srcexpr.text, destexpr.text, err))
            xformed_src = print_node(unitxform)
        else:
            self._srcunits = None

        self._srcexpr = ConnectedExprEvaluator(xformed_src, scope=self)

        # _orig_expr is just the equation string (for debugging)
        if self._orig_dest:
            self._outdests = [self._orig_dest]
            if pq is None:
                sunit = dunit = ''
            else:
                sunit = "'%s'" % pq.get_unit_name()
                dunit = "'%s'" % out_units
            self._orig_expr = "%s %s -> %s %s" % (self._orig_src, sunit,
                                                  self._orig_dest, dunit)
        else:
            self._outdests = []
            self._orig_expr = self._orig_src

        self.missing_deriv_policy = 'error'
        self._negate = False

    def __getstate__(self):
        """Return a copy of our state with the weakly referenced parent
        replaced by the parent object itself.
        """
        state = self.__dict__.copy()
        state['_parent'] = self.parent
        return state

    def __setstate__(self, state):
        """Restore state saved by ``__getstate__``, re-wrapping the parent in
        a weak reference.
        """
        # Without this update every attribute except the parent was lost.
        self.__dict__.update(state)
        self.parent = state['_parent']

    @property
    def parent(self):
        """ Our parent assembly. """
        return None if self._parent is None else self._parent()

    @parent.setter
    def parent(self, parent):
        self._parent = None if parent is None else weakref.ref(parent)

    def check_config(self, strict=False):
        """No-op hook; this class performs no configuration checks."""
        pass

    def cpath_updated(self):
        """No-op hook."""
        pass

    def is_differentiable(self):
        """Return True if analytical derivatives can be
        computed for this Component.
        """
        return True

    def get_pathname(self, rel_to_scope=None):
        """ Return full pathname to this object, relative to scope
        *rel_to_scope*. If *rel_to_scope* is *None*, return the full pathname.
        """
        return '.'.join((self.parent.get_pathname(rel_to_scope), self.name))

    def list_connections(self, is_hidden=False, show_expressions=False):
        """List all of the input and output connections of this PseudoComponent
        as (src, dest) tuples.

        If is_hidden is True, list the connections that a user would see
        if this PseudoComponent is hidden. If show_expressions is True (and
        only if is_hidden is also True) then list the connection expression
        that resulted in the creation of this PseudoComponent.
        """
        if is_hidden:
            if not self._outdests:
                return []
            if show_expressions:
                return [(self._orig_src, self._orig_dest)]
            return [(src, self._outdests[0]) for src in self._inmap if src]

        conns = [(src, '.'.join((self.name, dest)))
                 for src, dest in self._inmap.items()]
        if self._outdests:
            conns.extend([('.'.join((self.name, 'out0')), dest)
                          for dest in self._outdests])
        return conns

    def list_inputs(self, connected=True):
        """Return a copy of the list of input names; *connected* is ignored."""
        return self._inputs[:]

    def list_outputs(self, connected=True):
        """Return the output names (always just ``out0``)."""
        return ['out0']

    def config_changed(self, update_parent=True):
        """No-op hook."""
        pass

    def list_comp_connections(self):
        """Return a list of connections between our pseudocomp and
        parent components of our sources/destinations.
        """
        conns = [(src.split('.', 1)[0], self.name) for src in self._inmap]
        if self._outdests:
            conns.extend([(self.name, dest.split('.', 1)[0])
                          for dest in self._outdests])
        return conns

    def contains(self, name):
        """Return True if *name* is ``out0`` or one of our inputs."""
        return name == 'out0' or name in self._inputs

    def activate(self, scope, driver=None):
        """Add this pseudocomp to *scope* and its dependency graph, then
        make its connections there.
        """
        scope.add(self.name, self)
        scope._depgraph.add_component(self.name, self)
        getattr(scope, self.name).make_connections(scope, driver)

    def make_connections(self, scope, driver=None):
        """Connect all of the inputs and outputs of this comp to
        the appropriate nodes in the dependency graph.
        """
        for src, dest in self.list_connections():
            scope._depgraph.connect(scope, src, dest)

        if driver is not None:
            # NOTE(review): this branch has no statements in this copy of the
            # file; restore any driver-specific hookup here.
            pass

    def run(self, case_uuid=''):
        """Evaluate the source expression into ``out0``, negated when
        ``_negate`` is set. *case_uuid* is not used here.
        """
        value = self._srcexpr.evaluate()
        self.out0 = -value if self._negate else value

    def evaluate(self):
        """Evaluate the source expression into ``out0``, negated when
        ``_negate`` is set.
        """
        value = self._srcexpr.evaluate()
        self.out0 = -value if self._negate else value

    def get(self, name):
        """Return attribute *name*."""
        return getattr(self, name)

    def set(self, path, value):
        """Set attribute *path* to *value*."""
        setattr(self, path, value)

    def get_metadata(self, traitpath, metaname=None):
        """Return the metadata dict for *traitpath*, or the single entry
        *metaname* (None if absent). When *metaname* is given, a dotted
        *traitpath* is delegated to the named child's ``get_metadata``.
        """
        if metaname is None:
            return self._meta[traitpath]
        childname, _, restofpath = traitpath.partition('.')
        if restofpath:
            return getattr(self, childname).get_metadata(restofpath, metaname)
        return self._meta[traitpath].get(metaname)

    def set_itername(self, itername):
        """Record the iteration name."""
        self._itername = itername

    def linearize(self, first=False, second=False):
        """Component wrapper for the ProvideJ hook.

        Returns ``provideJ()`` when *first* is True, raises RuntimeError for
        *second*, and returns None otherwise.
        """
        if first:
            return self.provideJ()
        if second:
            msg = "2nd derivatives not supported in pseudocomponent %s"
            raise RuntimeError(msg % self.name)

    def provideJ(self):
        """Calculate analytical first derivatives.

        Returns an (n_out, n_in) array whose columns follow the order of
        ``self._inputs``, negated when ``_negate`` is set; the shape is
        cached in ``self.Jsize``.
        """
        if self.Jsize is None:
            n_in = sum(flattened_size(name, self.get(name), self)
                       for name in self.list_inputs())
            n_out = sum(flattened_size(name, self.get(name), self)
                        for name in self.list_outputs())
            self.Jsize = (n_out, n_in)

        J = zeros(self.Jsize)
        grad = self._srcexpr.evaluate_gradient()

        i = 0
        for varname in self._inputs:
            width = flattened_size(varname, self.get(varname), self)
            J[:, i:i + width] = grad[varname]
            i += width

        return -J if self._negate else J

    def ensure_init(self):
        """Make sure our inputs and outputs have been initialized."""
        # set the current value of the connected variable
        # into our input
        for ref, in_name in self._inmap.items():
            # NOTE(review): the value argument of this setattr was cut off in
            # this copy of the file; evaluating *ref* in the parent scope is
            # a reconstruction -- confirm against upstream.
            setattr(self, in_name,
                    ConnectedExprEvaluator(ref, scope=self.parent).evaluate())
            if has_interface(getattr(self, in_name), IContainer):
                getattr(self, in_name).name = in_name

        # set the initial value of the output
        self.out0 = self._srcexpr.evaluate()

    def list_deriv_vars(self):
        """Return (input names, output names) used for derivatives."""
        return tuple(self._inputs), ('out0',)

    def get_req_cpus(self):
        """Return the CPU requirement, always (1, 1)."""
        return (1, 1)

    def setup_init(self):
        """Clear the cached Jacobian size and bounds."""
        self.Jsize = None
        self._provideJ_bounds = None

    def init_var_sizes(self):
        """No-op hook."""
        pass

    def setup_depgraph(self, dgraph):
        """No-op hook."""
        pass

    def setup_systems(self):
        """Return an empty tuple; a pseudocomp creates no subsystems."""
        return ()

    def setup_communicators(self, comm, scope=None):
        """Store *comm* on our MPI info object."""
        self.mpi.comm = comm

    def setup_variables(self):
        """No-op hook."""
        pass

    def setup_sizes(self):
        """No-op hook."""
        pass

    def setup_vectors(self, arrays=None):
        """No-op hook."""
        pass

    def post_setup(self):
        """No-op hook."""
        pass

    def get_flattened_value(self, path):
        """Return the named value, which may include
        an array index, as a flattened array of floats.  If
        the value is not flattenable into an array of floats,
        raise a TypeError.
        """
        val, idx = get_val_and_index(self, path)
        return flattened_value(path, val)

    def set_flattened_value(self, path, value):
        """Set the variable named by *path* (optionally with an array index)
        from the flattened array *value*.

        Raises NotImplementedError for variable trees, and TypeError for
        integers, indexed scalars and any other unsupported value type.
        """
        val, _ = deep_getattr(self, path.split('[', 1)[0])
        idx = get_index(path)
        if isinstance(val, int_types):
            pass  # fall through to exception
        elif isinstance(val, complex_or_real_types):
            if idx is None:
                setattr(self, path, value[0])
                return
            # else, fall through to error
        elif isinstance(val, ndarray):
            if idx is None:
                setattr(self, path, value)
            else:
                val[idx] = value
            return
        elif IVariableTree.providedBy(val):
            raise NotImplementedError("no support for setting flattened values into vartrees")

        raise TypeError("%s: Failed to set flattened value to variable %s" % (self.name, path))

    def get_req_default(self, self_reqired=None):
        """Return an empty list."""
        return []

    def _input_updated(self, name, fullpath=None):
        """No-op hook."""
        pass

    def get_full_nodeset(self):
        """Return the full set of nodes in the depgraph
        belonging to this component.
        """
        return set((self.name,))

    def applyJ(self, system, variables):
        """ Wrapper for component derivative specification methods.
        Forward Mode.
        """
        # resolves to the module-level applyJ, not this method
        applyJ(system, variables)

    def applyJT(self, system, variables):
        """ Wrapper for component derivative specification methods.
        Adjoint Mode.
        """
        # resolves to the module-level applyJT, not this method
        applyJT(system, variables)

    def raise_exception(self, msg, exception_class=Exception):
        """Raise *exception_class* with *msg* prefixed by our pathname and,
        if found, the itername of the nearest object up the parent chain
        that has ``get_itername``.
        """
        coords = ''
        obj = self
        while obj is not None:
            try:
                coords = obj.get_itername()
            except AttributeError:
                try:
                    obj = obj.parent
                except AttributeError:
                    break
            else:
                break
        if coords:
            full_msg = '%s (%s): %s' % (self.get_pathname(), coords, msg)
        else:
            full_msg = '%s: %s' % (self.get_pathname(), msg)
        raise exception_class(full_msg)
class PseudoComponent(object):
    """A 'fake' component that is constructed from an ExprEvaluator.
    This fake component can be added to a dependency graph and executed
    along with 'real' components.

    Each variable referenced by the source expression becomes an input named
    ``in0``, ``in1``, ...; the value of the expression is stored in the single
    output ``out0``.
    """

    def __init__(self, parent, srcexpr, destexpr=None, translate=True, pseudo_type=None):
        """Build the pseudocomponent from *srcexpr* and optional *destexpr*.

        parent: object
            Owning scope; ``get_pathname`` delegates to its ``get_pathname``.

        srcexpr: expression evaluator
            Source expression; each entry of ``srcexpr.refs()`` becomes an
            input ``inN``.

        destexpr: expression evaluator or None
            Destination of ``out0``. ``None`` is replaced by ``DummyExpr()``.

        translate: bool
            If True, rewrite the source text in terms of the ``inN`` names
            before compiling it.

        pseudo_type: str or None
            e.g. 'units', 'constraint', 'objective' or 'multi_var_expr'.

        Raises RuntimeError if *destexpr* does not reference exactly one
        variable, and TypeError if ``unit_xform`` rejects the combination of
        source and destination units.
        """
        if destexpr is None:
            destexpr = DummyExpr()
        self.name = _get_new_name()
        self._inmap = {}  # mapping of component vars to our inputs
        self._meta = {}  # metadata dict keyed by our local names (inN/out0)
        self._valid = False
        self._parent = parent
        self._inputs = []  # our input names, in creation order
        self.force_fd = False
        self._provideJ_bounds = None
        # a string indicating the type of pseudocomp this is, e.g. 'units',
        # 'constraint', 'objective', or 'multi_var_expr'
        self._pseudo_type = pseudo_type
        self._orig_src = srcexpr.text
        self._orig_dest = destexpr.text
        self.Jsize = None  # (n_out, n_in); computed lazily by provideJ

        if destexpr.text:
            self._outdests = [destexpr.text]
        else:
            self._outdests = []

        varmap = {}   # variable reference -> our local name
        rvarmap = {}  # _get_varname(ref) -> set of refs sharing that name
        for i, ref in enumerate(srcexpr.refs()):
            in_name = 'in%d' % i
            # list_inputs, provideJ and the units check below all read
            # self._inputs, so every input must be recorded here.
            self._inputs.append(in_name)
            self._inmap[ref] = in_name
            varmap[ref] = in_name
            rvarmap.setdefault(_get_varname(ref), set()).add(ref)
            setattr(self, in_name, None)

        refs = list(destexpr.refs())
        # An empty list used to fail later with an IndexError on refs[0];
        # report it with the same error as the too-many-refs case.
        if len(refs) != 1:
            raise RuntimeError("output of PseudoComponent must reference only one variable")
        setattr(self, 'out0', None)
        varmap[refs[0]] = 'out0'
        rvarmap.setdefault(_get_varname(refs[0]), set()).add(refs[0])

        # Copy each referenced variable's metadata onto the local name(s)
        # standing in for it (source first, then destination).
        for expr in (srcexpr, destexpr):
            for name, meta in expr.get_metadata():
                for rname in rvarmap[name]:
                    self._meta[varmap[rname]] = meta

        if translate:
            xformed_src = transform_expression(srcexpr.text, self._inmap)
        else:
            xformed_src = srcexpr.text

        out_units = self._meta['out0'].get('units')
        if out_units is not None:
            # Evaluate the src expression with UnitsOnlyPQ stand-ins (0. for
            # unitless inputs) to discover the units it produces.
            tmpdict = {}
            for inp in self._inputs:
                units = self._meta[inp].get('units')
                if units:
                    tmpdict[inp] = UnitsOnlyPQ(0., units)
                else:
                    tmpdict[inp] = 0.

            pq = eval(xformed_src, _expr_dict, tmpdict)
            self._srcunits = pq.unit

            # Rewrite the expression AST with unit_xform (source units ->
            # out_units) and print it back to source text.
            unitnode = ast.parse(xformed_src)
            try:
                unitxform = unit_xform(unitnode, self._srcunits, out_units)
            except Exception as err:
                raise TypeError("Incompatible units for '%s' and '%s': %s" % (srcexpr.text,
                                                                    destexpr.text, err))
            xformed_src = print_node(unitxform)
        else:
            self._srcunits = None

        self._srcexpr = ConnectedExprEvaluator(xformed_src, scope=self)

        # this is just the equation string (for debugging)
        if destexpr and destexpr.text:
            out = destexpr.text
        else:
            out = 'out0'
        if translate:
            # With translate on, the compiled text uses our inN names
            # (self._inmap maps ref -> inN), so map them back to the
            # original references for display.
            src = transform_expression(self._srcexpr.text,
                                       dict((v, k) for k, v in self._inmap.items()))
        else:
            src = self._srcexpr.text

        self._expr_conn = (src, out)  # the actual expression connection

        self.missing_deriv_policy = 'error'

    def check_configuration(self):
        """No-op hook; this class performs no configuration checks."""
        pass

    def cpath_updated(self):
        """No-op hook."""
        pass

    def get_pathname(self, rel_to_scope=None):
        """ Return full pathname to this object, relative to scope
        *rel_to_scope*. If *rel_to_scope* is *None*, return the full pathname.
        """
        return '.'.join([self._parent.get_pathname(rel_to_scope), self.name])

    def list_connections(self, is_hidden=False, show_expressions=False):
        """List all of the input and output connections of this PseudoComponent
        as (src, dest) tuples.

        If is_hidden is True, list the connections that a user would see
        if this PseudoComponent is hidden. If show_expressions is True (and
        only if is_hidden is also True) then list the connection expression
        that resulted in the creation of this PseudoComponent.
        """
        if is_hidden:
            if not self._outdests:
                return []
            if show_expressions:
                return [self._expr_conn]
            return [(src, self._outdests[0]) for src in self._inmap if src]

        conns = [(src, '.'.join([self.name, dest]))
                 for src, dest in self._inmap.items()]
        if self._outdests:
            conns.extend([('.'.join([self.name, 'out0']), dest)
                          for dest in self._outdests])
        return conns

    def list_inputs(self):
        """Return a copy of the list of input names."""
        return self._inputs[:]

    def list_outputs(self):
        """Return the output names (always just ``out0``)."""
        return ['out0']

    def list_comp_connections(self):
        """Return a list of connections between our pseudocomp and
        parent components of our sources/destinations.
        """
        conns = [(src.split('.', 1)[0], self.name) for src in self._inmap]
        if self._outdests:
            conns.extend([(self.name, dest.split('.', 1)[0])
                          for dest in self._outdests])
        return conns

    def make_connections(self, scope):
        """Connect all of the inputs and outputs of this comp to
        the appropriate nodes in the dependency graph.
        """
        for src, dest in self.list_connections():
            scope.connect(src, dest)

    def remove_connections(self, scope):
        """Disconnect all of the inputs and outputs of this comp
        from other nodes in the dependency graph.
        """
        for src, dest in self.list_connections():
            scope.disconnect(src, dest)

    def invalidate_deps(self, varnames=None, force=False):
        """Mark this component invalid; the arguments are ignored.

        Always returns None.
        """
        self._valid = False
        return None

    def get_invalidation_type(self):
        """Return the invalidation type, which is always 'full'."""
        return 'full'

    def connect(self, src, dest):
        """Mark this component invalid; *src* and *dest* are not stored."""
        self._valid = False

    def run(self, ffd_order=0, case_id=''):
        """Evaluate the source expression into ``out0`` and mark this
        component valid. *ffd_order* and *case_id* are not used here.
        """
        self.out0 = self._srcexpr.evaluate()
        self._valid = True

    def update_inputs(self, inputs=None):
        """No-op hook."""
        pass

    def update_outputs(self, names):
        """No-op hook."""
        pass

    def get(self, name, index=None):
        """Return attribute *name*. Raises RuntimeError if *index* is given."""
        if index is not None:
            raise RuntimeError("index not supported in PseudoComponent.get")
        return getattr(self, name)

    def set(self, path, value, index=None, src=None, force=False):
        """Set attribute *path* to *value*; *src* and *force* are ignored.

        Raises ValueError if *index* is given.
        """
        if index is not None:
            raise ValueError("index not supported in PseudoComponent.set")
        setattr(self, path, value)

    def get_metadata(self, traitpath, metaname=None):
        """Return the metadata dict for *traitpath*, or the single entry
        *metaname* from it (None if absent).
        """
        if metaname is None:
            return self._meta[traitpath]
        return self._meta[traitpath].get(metaname)

    def is_valid(self):
        """Return True if ``run`` has completed since the last invalidation."""
        return self._valid

    def set_itername(self, itername):
        """Record the iteration name."""
        self._itername = itername

    def calc_derivatives(self, first=False, second=False, savebase=False,
                         required_inputs=None, required_outputs=None):
        """Return ``provideJ()`` when *first* is True.

        Raises RuntimeError when *second* is requested; returns None when
        neither is. The remaining arguments are ignored.
        """
        if first:
            return self.provideJ()
        if second:
            msg = "2nd derivatives not supported in pseudocomponent %s"
            raise RuntimeError(msg % self.name)

    def provideJ(self):
        """Calculate analytical first derivatives.

        Returns an (n_out, n_in) array whose columns follow the order of
        ``self._inputs``; the shape is cached in ``self.Jsize``.
        """
        if self.Jsize is None:
            n_in = sum(flattened_size(name, self.get(name), self)
                       for name in self.list_inputs())
            n_out = sum(flattened_size(name, self.get(name), self)
                        for name in self.list_outputs())
            self.Jsize = (n_out, n_in)

        J = zeros(self.Jsize)
        grad = self._srcexpr.evaluate_gradient()

        i = 0
        for varname in self._inputs:
            width = flattened_size(varname, self.get(varname), self)
            J[:, i:i + width] = grad[varname]
            i += width

        return J

    def list_deriv_vars(self):
        """Return (input names, output names) used for derivatives."""
        return tuple(self._inputs), ('out0',)
# Example #4
class PseudoComponent(object):
    """A 'fake' component that is constructed from an ExprEvaluator.
    This fake component can be added to a dependency graph and executed
    along with 'real' components.
    """
    implements(IComponent, IPseudoComp)

    def __init__(self,
                 # NOTE(review): the remainder of this signature is missing
                 # from this copy of the file. The body reads parent, srcexpr,
                 # destexpr, translate, pseudo_type, subtype and exprobject,
                 # so those are presumably the parameters; restore them from
                 # upstream.
        # Build a pseudocomp that evaluates srcexpr, with each referenced
        # variable mapped to an input named 'in<N>', and maps the destination
        # reference to the single output 'out0'. When the output has units,
        # the expression is rewritten to convert the source units to them.
        if destexpr is None:
            destexpr = DummyExpr()
        self._parent = None
        self.parent = parent
        self.name = _get_new_name(parent)
        self._inmap = {}  # mapping of component vars to our inputs
        self._meta = {}
        self._inputs = []
        self._initialized = False

        # Flags and caching used by the derivatives calculation
        self.force_fd = False
        self._provideJ_bounds = None

        self._pseudo_type = pseudo_type  # a string indicating the type of pseudocomp
        # this is, e.g., 'units', 'constraint', 'objective',
        # or 'multi_var_expr'
        self._subtype = subtype  # for constraints, 'equality' or 'inequality'

        self._exprobj = exprobject  # object responsible for creation of this pcomp, e.g. a Constraint

        self._orig_src = srcexpr.text
        self._orig_dest = destexpr.text
        self.Jsize = None
        self.mpi = MPI_info()

        # varmap: expression ref -> local name ('inN' / 'out0')
        # rvarmap: bare variable name -> set of refs that use it
        varmap = {}
        rvarmap = {}
        # NOTE(review): nothing in this method appends to self._inputs, yet
        # the units loop below, list_inputs() and provideJ() read it; a
        # `self._inputs.append(in_name)` line may be missing from this loop.
        for i, ref in enumerate(srcexpr.ordered_refs()):
            in_name = 'in%d' % i
            self._inmap[ref] = in_name
            varmap[ref] = in_name
            rvarmap.setdefault(_get_varname(ref), set()).add(ref)
            setattr(self, in_name, 0.)

        refs = list(destexpr.refs())
        if refs:
            # NOTE(review): as written, a single ref both creates out0 and
            # raises; an `else:` before the raise appears to be missing.
            if len(refs) == 1:
                setattr(self, 'out0', None)
                raise RuntimeError("output of PseudoComponent must reference"
                                   " only one variable")
        # NOTE(review): these two lines index refs[0] and raise IndexError
        # when refs is empty; they may be intended to sit inside `if refs:`.
        varmap[refs[0]] = 'out0'
        rvarmap.setdefault(_get_varname(refs[0]), set()).add(refs[0])

        # copy variable metadata onto our local input/output names
        for name, meta in srcexpr.get_metadata():
            for rname in rvarmap[name]:
                self._meta[varmap[rname]] = meta

        for name, meta in destexpr.get_metadata():
            for rname in rvarmap[name]:
                self._meta[varmap[rname]] = meta

        if translate:
            xformed_src = transform_expression(srcexpr.text, self._inmap)
            # NOTE(review): an `else:` appears to be missing above the next
            # line; as written the translated expression is always discarded.
            xformed_src = srcexpr.text

        out_units = self._meta['out0'].get('units')
        pq = None
        if out_units is not None:
            # evaluate the src expression using UnitsOnlyPQ objects

            tmpdict = {}

            # First, replace values with UnitsOnlyPQ objects
            for inp in self._inputs:
                units = self._meta[inp].get('units')
                if units:
                    tmpdict[inp] = UnitsOnlyPQ(0., units)
                    # NOTE(review): presumably the body of a missing `else:`;
                    # as written every input ends up as 0.
                    tmpdict[inp] = 0.

            # NOTE(review): eval of expression text -- only safe when the
            # expression comes from trusted model code.
            pq = eval(xformed_src, _expr_dict, tmpdict)
            self._srcunits = pq.unit

            unitnode = ast.parse(xformed_src)
            # NOTE(review): a `try:` line appears to be missing here.
                unitxform = unit_xform(unitnode, self._srcunits, out_units)
            except Exception as err:
                raise TypeError("Incompatible units for '%s' and '%s': %s" %
                                (srcexpr.text, destexpr.text, err))
            unit_src = print_node(unitxform)
            xformed_src = unit_src
            # NOTE(review): likely the body of a missing `else:` for
            # `if out_units is not None`; as written it clears _srcunits
            # right after it was set.
            self._srcunits = None

        self._srcexpr = ConnectedExprEvaluator(xformed_src, scope=self)

        # this is just the equation string (for debugging)
        if self._orig_dest:
            self._outdests = [self._orig_dest]
            if pq is None:
                sunit = dunit = ''
                # NOTE(review): an `else:` appears to be missing; as written,
                # pq is None here and the next line raises AttributeError.
                sunit = "'%s'" % pq.get_unit_name()
                dunit = "'%s'" % out_units
            self._orig_expr = "%s %s -> %s %s" % (self._orig_src, sunit,
                                                  self._orig_dest, dunit)
            # NOTE(review): presumably the `else:` branch of
            # `if self._orig_dest` (keyword missing); as written these two
            # lines overwrite the values just computed.
            self._outdests = []
            self._orig_expr = self._orig_src

        self.missing_deriv_policy = 'error'
        self._negate = False

    def __getstate__(self):
        state = self.__dict__.copy()
        state['_parent'] = self.parent
        return state

    def __setstate__(self, state):
        self.parent = state['_parent']

    def parent(self):
        """ Our parent assembly. """
        return None if self._parent is None else self._parent()

    def parent(self, parent):
        self._parent = None if parent is None else weakref.ref(parent)

    def check_config(self, strict=False):
        """Do nothing: no configuration checks apply to a PseudoComponent."""

    def cpath_updated(self):
        """Do nothing: a PseudoComponent needs no path-update handling."""

    def is_differentiable(self):
        """Return True if analytical derivatives can be
        computed for this Component.

        Always True: a PseudoComponent's Jacobian comes from the gradient of
        its source expression (see ``provideJ``).
        """
        return True

    def get_pathname(self, rel_to_scope=None):
        """ Return full pathname to this object, relative to scope
        *rel_to_scope*. If *rel_to_scope* is *None*, return the full pathname.
        return '.'.join((self.parent.get_pathname(rel_to_scope), self.name))

    def list_connections(self, is_hidden=False, show_expressions=False):
        """list all of the inputs and output connections of this PseudoComponent.
        If is_hidden is True, list the connections that a user would see
        if this PseudoComponent is hidden. If show_expressions is True (and
        only if is_hidden is also True) then list the connection expression
        that resulted in the creation of this PseudoComponent.
        if is_hidden:
            if self._outdests:
                if show_expressions:
                    return [(self._orig_src, self._orig_dest)]
                    return [(src, self._outdests[0])
                            for src in self._inmap.keys() if src]
                return []
            conns = [(src, '.'.join((self.name, dest)))
                     for src, dest in self._inmap.items()]
            if self._outdests:
                conns.extend([('.'.join((self.name, 'out0')), dest)
                              for dest in self._outdests])
        return conns

    def list_inputs(self, connected=True):
        return self._inputs[:]

    def list_outputs(self, connected=True):
        """Return our output names: always just ``['out0']``."""
        outputs = ['out0']
        return outputs

    def config_changed(self, update_parent=True):
        """Do nothing: a PseudoComponent keeps no config-dependent state."""

    def list_comp_connections(self):
        """Return a list of connections between our pseudocomp and
        parent components of our sources/destinations.
        conns = [(src.split('.', 1)[0], self.name)
                 for src, dest in self._inmap.items()]
        if self._outdests:
            conns.extend([(self.name, dest.split('.', 1)[0])
                          for dest in self._outdests])
        return conns

    def contains(self, name):
        return name == 'out0' or name in self._inputs

    def make_connections(self, scope, driver=None):
        """Connect all of the inputs and outputs of this comp to
        the appropriate nodes in the dependency graph.
        for src, dest in self.list_connections():
            scope.connect(src, dest)

        if driver is not None:
            scope._depgraph.add_driver_input(driver.name, self.name + '.out0')

    def run(self, case_uuid=''):
        if self._negate:
            setattr(self, 'out0', -self._srcexpr.evaluate())
            setattr(self, 'out0', self._srcexpr.evaluate())

    def evaluate(self):
        if self._negate:
            setattr(self, 'out0', -self._srcexpr.evaluate())
            setattr(self, 'out0', self._srcexpr.evaluate())

    def get(self, name):
        return getattr(self, name)

    def set(self, path, value):
        setattr(self, path, value)

    def get_metadata(self, traitpath, metaname=None):
        if metaname is None:
            return self._meta[traitpath]
        childname, _, restofpath = traitpath.partition('.')
        if restofpath:
            return getattr(self, childname).get_metadata(restofpath, metaname)
            return self._meta[traitpath].get(metaname)

    def set_itername(self, itername):
        self._itername = itername

    def linearize(self, first=False, second=False):
        """Component wrapper for the ProvideJ hook."""
        if first:
            return self.provideJ()
        if second:
            msg = "2nd derivatives not supported in pseudocomponent %s"
            raise RuntimeError(msg % self.name)

    def provideJ(self):
        """Calculate analytical first derivatives.

        Returns a dense Jacobian of shape ``(n_out, n_in)``. The sizes are
        the flattened widths of the outputs and inputs, computed on the
        first call and cached in ``self.Jsize``. The columns for each input
        are filled from ``self._srcexpr.evaluate_gradient()``. The result is
        negated when ``self._negate`` is set.
        """
        if self.Jsize is None:
            n_in = 0
            n_out = 0
            for varname in self.list_inputs():
                n_in += flattened_size(varname, self.get(varname), self)
            for varname in self.list_outputs():
                n_out += flattened_size(varname, self.get(varname), self)
            self.Jsize = (n_out, n_in)

        J = zeros(self.Jsize)
        grad = self._srcexpr.evaluate_gradient()

        i = 0
        for varname in self._inputs:
            width = flattened_size(varname, self.get(varname), self)
            J[:, i:i + width] = grad[varname]
            i += width

        # Keep the Jacobian consistent with the (possibly negated) output
        # value; the non-negated case must still return J.
        if self._negate:
            return -J
        return J

    def ensure_init(self):
        """Make sure our inputs and outputs have been initialized.

        Runs only once, guarded by ``self._initialized``. Each input named in
        ``self._inmap`` is set from its connected variable, and ``out0`` is
        set by evaluating the source expression.
        """
        if not self._initialized:
            # set the current value of the connected variable
            # into our input
            for ref, in_name in self._inmap.items():
                setattr(self, in_name,
                        # NOTE(review): the value argument of this setattr
                        # call is missing from this copy of the file
                        # (presumably the current value of `ref`); restore
                        # it from upstream.
                if has_interface(getattr(self, in_name), IContainer):
                    # container-valued inputs are renamed after the input
                    # slot they now occupy
                    getattr(self, in_name).name = in_name

            # set the initial value of the output
            setattr(self, 'out0', self._srcexpr.evaluate())

            self._initialized = True

    def list_deriv_vars(self):
        return tuple(self._inputs), ('out0', )

    def get_req_cpus(self):
        """Return the number of CPUs this pseudocomp requests (always one)."""
        required = 1
        return required

    def pre_setup(self):
        """Do nothing: a PseudoComponent needs no pre-setup work."""

    def setup_depgraph(self, dgraph):
        """Do nothing: a PseudoComponent adds nothing to *dgraph* here."""

    def setup_systems(self):
        """Return an empty tuple: a PseudoComponent owns no subsystems."""
        return tuple()

    def setup_communicators(self, comm, scope=None):
        self.mpi.comm = comm

    def setup_variables(self):
        """Do nothing: a PseudoComponent needs no variable setup."""

    def setup_sizes(self):
        """Do nothing: a PseudoComponent needs no size setup."""

    def setup_vectors(self, arrays=None):
        """Do nothing: a PseudoComponent needs no vector setup."""

    def post_setup(self):
        """Do nothing: a PseudoComponent needs no post-setup work."""

    def get_flattened_value(self, path):
        """Return the named value, which may include
        an array index, as a flattened array of floats.  If
        the value is not flattenable into an array of floats,
        raise a TypeError.
        """
        # Only the value is needed; the index returned alongside it is unused.
        val, _ = get_val_and_index(self, path)
        return flattened_value(path, val)

    def set_flattened_value(self, path, value):
        """Set the variable named by *path* from the flattened array *value*.

        *path* may carry an array index (e.g. ``'x[2]'``).

        - An unindexed real/complex scalar target takes ``value[0]``.
        - An array target is replaced whole, or updated in place at the index.

        Raises:
            NotImplementedError: if the target is a variable tree.
            TypeError: for int targets, indexed scalars, or any other
                unsupported type.
        """
        val, _ = deep_getattr(self, path.split('[', 1)[0])
        idx = get_index(path)
        if isinstance(val, int_types):
            pass  # fall through to exception
        elif isinstance(val, complex_or_real_types):
            if idx is None:
                setattr(self, path, value[0])
                return
            # else, fall through to error
        elif isinstance(val, ndarray):
            if idx is None:
                setattr(self, path, value)
            else:
                val[idx] = value
            return
        elif IVariableTree.providedBy(val):
            raise NotImplementedError(
                "no support for setting flattened values into vartrees")

        raise TypeError("%s: Failed to set flattened value to variable %s" %
                        (self.name, path))

    def get_req_default(self, self_reqired=None):
        """Return a new empty list: a PseudoComponent has no required defaults."""
        return list()

    def _input_updated(self, name, fullpath=None):

    def get_full_nodeset(self):
        """Return the full set of nodes in the depgraph
        belonging to this component.
        return set((self.name, ))

    def applyJ(self, system, variables):
        """ Wrapper for component derivative specification methods.
        Forward Mode.
        """
        # Calls the module-level applyJ function: class attributes are not in
        # scope inside a method body, so this is not a recursive call.
        applyJ(system, variables)

    def applyJT(self, system, variables):
        """ Wrapper for component derivative specification methods.
        Adjoint Mode.
        """
        # Calls the module-level applyJT function: class attributes are not in
        # scope inside a method body, so this is not a recursive call.
        applyJT(system, variables)

    def raise_exception(self, msg, exception_class=Exception):
        """Raise an exception."""
        coords = ''
        obj = self
        while obj is not None:
                coords = obj.get_itername()
            except AttributeError:
                    obj = obj.parent
                except AttributeError:
        if coords:
            full_msg = '%s (%s): %s' % (self.get_pathname(), coords, msg)
            full_msg = '%s: %s' % (self.get_pathname(), msg)
        raise exception_class(full_msg)