Ejemplo n.º 1
0
 def test_get_referenced_varpaths(self):
     """Check the varpaths reported for an assignment, a nested index and
     an expression evaluated after changing the evaluator's scope."""
     assignment = 'comp.x[0] = 10*(3.2+ a.a1d[3]* 1.1*a.a1d[2 ])'
     evaluator = ExprEvaluator(assignment, self.top.a)
     both_sides = set(['comp.x', 'a.a1d'])
     self.assertEqual(evaluator.get_referenced_varpaths(), both_sides)

     # Nested indexing is expected to report only the container varpath.
     container_only = set(['comp.contlist'])
     evaluator.text = 'comp.contlist[1].a2d[2][1]'
     self.assertEqual(evaluator.get_referenced_varpaths(), container_only)

     # After switching scope, the same varpath text is expected.
     evaluator.scope = self.top.comp
     evaluator.text = 'comp.contlist[1]'
     self.assertEqual(evaluator.get_referenced_varpaths(), container_only)
 def test_get_referenced_varpaths(self):
     """Verify get_referenced_varpaths() as the text and scope change."""
     evaluator = ExprEvaluator(
         'comp.x[0] = 10*(3.2+ a.a1d[3]* 1.1*a.a1d[2 ])', self.top.a)

     def assert_paths(*names):
         # Compare the reported varpaths against exactly the given names.
         self.assertEqual(evaluator.get_referenced_varpaths(), set(names))

     assert_paths('comp.x', 'a.a1d')

     evaluator.text = 'comp.contlist[1].a2d[2][1]'
     assert_paths('comp.contlist')

     evaluator.scope = self.top.comp
     evaluator.text = 'comp.contlist[1]'
     assert_paths('comp.contlist')
Ejemplo n.º 3
0
 def __init__(self, exprs=(), derivatives=(), sleep=0, dsleep=0):
     """Build a component from assignment and derivative expressions.

     exprs: iterable of str
         Assignment statements such as ``'y = 2*x'`` or ``'a, b = x, y'``.
         Each is compiled into ``self.codes``.  Names on the left-hand side
         become Float variables with iotype 'out'; every other undotted name
         referenced becomes a Float with iotype 'in'.  Dotted names refer to
         variables outside this component and get no trait.
     derivatives: iterable of str
         Assignments whose left-hand side matches ``d<num>_d<wrt>``
         (e.g. ``'dy_dx = 2'``).  Each is compiled into
         ``self.derivative_codes`` and declared as a first derivative.
     sleep, dsleep: numbers
         Stored unchanged as ``self.sleep`` and ``self.dsleep``.

     Raises ValueError (via raise_exception) if a derivative expression
     references a variable not defined in ``exprs``, or if a derivative's
     left-hand side does not match the ``d<num>_d<wrt>`` pattern.
     """
     super(ExecCompWithDerivatives, self).__init__()

     outs = set()
     allvars = set()
     self.exprs = exprs
     self.codes = [compile(expr, '<string>', 'exec') for expr in exprs]
     self.sleep = sleep
     self.dsleep = dsleep

     for expr in exprs:
         # Split on the first '=' only, so '>=', '==' or keyword arguments
         # on the right-hand side do not break the unpacking.
         lhs, _ = expr.split('=', 1)
         # Strip every name: 'a, b = ...' must record 'b', not ' b',
         # otherwise 'b' would be misclassified as an input below.
         outs.update(name.strip() for name in lhs.split(','))
         expreval = ExprEvaluator(expr, scope=self)
         allvars.update(expreval.get_referenced_varpaths(copy=False))

     for var in allvars:
         # A dotted name lives outside our scope, so no trait is added.
         if '.' not in var:
             iotype = 'out' if var in outs else 'in'
             self.add(var, Float(0.0, iotype=iotype))

     self.deriv_exprs = derivatives
     self.derivative_codes = \
         [compile(expr, '<string>', 'exec') for expr in derivatives]

     self.derivative_names = []
     regex = re.compile('d(.*)_d(.*)')
     for expr in derivatives:
         expreval = ExprEvaluator(expr, scope=self)
         exvars = expreval.get_referenced_varpaths(copy=False)

         lhs, _ = expr.split('=', 1)
         lhs = lhs.strip()
         # The derivative's own name counts as defined for the check below.
         allvars.add(lhs)

         # Anything referenced that exprs did not define is an error.
         if exvars - allvars:
             self.raise_exception('derivative references a variable '
                                  'that is not defined in exprs',
                                  ValueError)

         names = regex.findall(lhs)
         if not names:
             self.raise_exception("derivative name '%s' does not match "
                                  "the pattern 'd<num>_d<wrt>'" % lhs,
                                  ValueError)
         num, wrt = names[0]
         self.derivatives.declare_first_derivative(num, wrt)

         self.derivative_names.append((lhs, num, wrt))
Ejemplo n.º 4
0
 def __init__(self, exprs=(), derivatives=(), sleep=0, dsleep=0):
     """Build a component from assignment and derivative expressions.

     exprs: iterable of str
         Assignment statements such as ``'y = 2*x'`` or ``'a, b = x, y'``.
         Each is compiled into ``self.codes``.  Names on the left-hand side
         become Float variables with iotype 'out'; every other undotted name
         referenced becomes a Float with iotype 'in'.  Dotted names refer to
         variables outside this component and get no trait.
     derivatives: iterable of str
         Assignments whose left-hand side matches ``d<num>_d<wrt>``
         (e.g. ``'dy_dx = 2'``).  Each is compiled into
         ``self.derivative_codes`` and declared as a first derivative.
     sleep, dsleep: numbers
         Stored unchanged as ``self.sleep`` and ``self.dsleep``.

     Raises ValueError (via raise_exception) if a derivative expression
     references a variable not defined in ``exprs``, or if a derivative's
     left-hand side does not match the ``d<num>_d<wrt>`` pattern.
     """
     super(ExecCompWithDerivatives, self).__init__()

     outs = set()
     allvars = set()
     self.exprs = exprs
     self.codes = [compile(expr, '<string>', 'exec') for expr in exprs]
     self.sleep = sleep
     self.dsleep = dsleep

     for expr in exprs:
         # Split on the first '=' only, so '>=', '==' or keyword arguments
         # on the right-hand side do not break the unpacking.
         lhs, _ = expr.split('=', 1)
         # Strip every name: 'a, b = ...' must record 'b', not ' b',
         # otherwise 'b' would be misclassified as an input below.
         outs.update(name.strip() for name in lhs.split(','))
         expreval = ExprEvaluator(expr, scope=self)
         allvars.update(expreval.get_referenced_varpaths(copy=False))

     for var in allvars:
         # A dotted name lives outside our scope, so no trait is added.
         if '.' not in var:
             iotype = 'out' if var in outs else 'in'
             self.add(var, Float(0.0, iotype=iotype))

     self.deriv_exprs = derivatives
     self.derivative_codes = \
         [compile(expr, '<string>', 'exec') for expr in derivatives]

     self.derivative_names = []
     regex = re.compile('d(.*)_d(.*)')
     for expr in derivatives:
         expreval = ExprEvaluator(expr, scope=self)
         exvars = expreval.get_referenced_varpaths(copy=False)

         lhs, _ = expr.split('=', 1)
         lhs = lhs.strip()
         # The derivative's own name counts as defined for the check below.
         allvars.add(lhs)

         # Anything referenced that exprs did not define is an error.
         if exvars - allvars:
             self.raise_exception('derivative references a variable '
                                  'that is not defined in exprs',
                                  ValueError)

         names = regex.findall(lhs)
         if not names:
             self.raise_exception("derivative name '%s' does not match "
                                  "the pattern 'd<num>_d<wrt>'" % lhs,
                                  ValueError)
         num, wrt = names[0]
         self.derivatives.declare_first_derivative(num, wrt)

         self.derivative_names.append((lhs, num, wrt))
Ejemplo n.º 5
0
    def _translate_up(self, text, node):
        """Prefix the variable referenced by ``text`` with ``node``.

        A legal plain name is simply joined as ``node.text``.  Otherwise
        ``text`` is parsed with ExprEvaluator, one referenced varpath is
        popped from the returned set, and transform_expression is asked to
        map that varpath to ``node.varpath`` within ``text``.

        NOTE(review): presumably ``text`` references exactly one varpath;
        with several, set.pop() picks an arbitrary one.
        """
        if not is_legal_name(text):
            varpath = ExprEvaluator(text).get_referenced_varpaths().pop()
            mapping = {varpath: '.'.join([node, varpath])}
            return transform_expression(text, mapping)
        return '.'.join([node, text])
Ejemplo n.º 6
0
def _split_expr(text):
    """Take an expression string and return varpath, expr"""
    if text.startswith('@') or is_legal_name(text):
        return text, text
    
    expr = ExprEvaluator(text)
    return expr.get_referenced_varpaths().pop(), text
Ejemplo n.º 7
0
    def _translate_up(self, text, node):
        """Upscoping: rewrite ``text`` so its variable is prefixed by ``node``.

        Plain legal names become ``node.text``.  For any other expression,
        one varpath is popped from its referenced set (arbitrary if there
        are several) and transform_expression maps it to ``node.varpath``.
        """
        def _prefixed(path):
            # Join the node name and a path with a dot.
            return '.'.join([node, path])

        if is_legal_name(text):
            return _prefixed(text)

        referenced = ExprEvaluator(text).get_referenced_varpaths()
        varpath = referenced.pop()
        return transform_expression(text, {varpath: _prefixed(varpath)})
Ejemplo n.º 8
0
 def __init__(self, exprs=()):
     """Create Float inputs/outputs from a sequence of assignment strings.

     Each string in ``exprs`` is compiled into ``self.codes``.  Names on the
     left-hand side of an assignment become Floats with iotype 'out'; every
     other undotted name referenced becomes a Float with iotype 'in'.
     Dotted names refer to variables outside this component and get no
     trait.
     """
     super(ExecComp, self).__init__()
     outs = set()
     allvars = set()
     self.codes = [compile(expr, '<string>', 'exec') for expr in exprs]
     for expr in exprs:
         # Split on the first '=' only so '>=', '==' or keyword arguments
         # on the right-hand side do not break the unpacking.
         lhs, _ = expr.split('=', 1)
         # Strip each name so 'a, b = ...' records 'b' rather than ' b'.
         outs.update(name.strip() for name in lhs.split(','))
         expreval = ExprEvaluator(expr, scope=self)
         allvars.update(expreval.get_referenced_varpaths(copy=False))
     for var in allvars:
         # A dotted name is outside our scope, so don't add a trait for it.
         if '.' not in var:
             iotype = 'out' if var in outs else 'in'
             self.add(var, Float(iotype=iotype))
Ejemplo n.º 9
0
 def __init__(self, exprs=()):
     """Create Float inputs/outputs from a sequence of assignment strings.

     Each string in ``exprs`` is compiled into ``self.codes``.  Names on the
     left-hand side of an assignment become Floats with iotype 'out'; every
     other undotted name referenced becomes a Float with iotype 'in'.
     Dotted names refer to variables outside this component and get no
     trait.  (The default is an immutable tuple to avoid a shared mutable
     default argument.)
     """
     super(ExecComp, self).__init__()
     outs = set()
     allvars = set()
     self.codes = [compile(expr, '<string>', 'exec') for expr in exprs]
     for expr in exprs:
         expreval = ExprEvaluator(expr, scope=self)
         exvars = expreval.get_referenced_varpaths()
         # Split on the first '=' only so '>=', '==' or keyword arguments
         # on the right-hand side do not break the unpacking.
         lhs, _ = expr.split('=', 1)
         # Strip each name so 'a, b = ...' records 'b' rather than ' b'.
         outs.update(name.strip() for name in lhs.split(','))
         allvars.update(exvars)
     for var in allvars:
         # A dotted name is outside our scope, so don't add a trait for it.
         if '.' not in var:
             iotype = 'out' if var in outs else 'in'
             self.add(var, Float(iotype=iotype))
Ejemplo n.º 10
0
 def __init__(self, exprs=(), sleep=0):
     """Create Float inputs/outputs from a sequence of assignment strings.

     ``exprs`` is stored on ``self.exprs`` and each string is compiled into
     ``self.codes``; ``sleep`` is stored unchanged on ``self.sleep``.
     Names on the left-hand side of an assignment become Floats with
     iotype "out"; every other undotted name referenced becomes a Float with
     iotype "in".  Dotted names refer to variables outside this component
     and get no trait.
     """
     super(ExecComp, self).__init__()
     outs = set()
     allvars = set()
     self.exprs = exprs
     self.codes = [compile(expr, "<string>", "exec") for expr in exprs]
     self.sleep = sleep
     for expr in exprs:
         # Split on the first "=" only so ">=", "==" or keyword arguments
         # on the right-hand side do not break the unpacking.
         lhs, _ = expr.split("=", 1)
         # Strip each name so "a, b = ..." records "b" rather than " b".
         outs.update(name.strip() for name in lhs.split(","))
         expreval = ExprEvaluator(expr, scope=self)
         allvars.update(expreval.get_referenced_varpaths(copy=False))
     for var in allvars:
         # A dotted name is outside our scope, so don't add a trait for it.
         if "." not in var:
             iotype = "out" if var in outs else "in"
             self.add(var, Float(0.0, iotype=iotype))
Ejemplo n.º 11
0
    def __init__(self, exprs=(), sleep=0, trace=False):
        """Create Float inputs/outputs from a sequence of assignment strings.

        ``exprs`` is stored on ``self.exprs`` and each string is compiled
        into ``self.codes``; ``sleep`` and ``trace`` are stored unchanged.
        Names on the left-hand side of an assignment become Floats with
        iotype 'out'; every other undotted name referenced becomes a Float
        with iotype 'in'.  Dotted names refer to variables outside this
        component and get no trait.
        """
        super(ExecComp, self).__init__()
        outs = set()
        allvars = set()
        self.exprs = exprs
        self.codes = [compile(expr, '<string>', 'exec') for expr in exprs]
        self.sleep = sleep
        self.trace = trace
        for expr in exprs:
            # Split on the first '=' only so '>=', '==' or keyword
            # arguments on the right-hand side do not break the unpacking.
            lhs, _ = expr.split('=', 1)
            # Strip each name so 'a, b = ...' records 'b' rather than ' b'.
            outs.update(name.strip() for name in lhs.split(','))
            expreval = ExprEvaluator(expr, scope=self)
            allvars.update(expreval.get_referenced_varpaths(copy=False))

        for var in allvars:
            # A dotted name is outside our scope, so don't add a trait.
            if '.' not in var:
                iotype = 'out' if var in outs else 'in'
                self.add(var, Float(0.0, iotype=iotype))
Ejemplo n.º 12
0
class Constraint(object):
    """ Object that stores info for a single constraint.

    The constraint ``lhs comparator rhs`` is kept as two ExprEvaluators and
    a comparator string.  ``activate`` adds a PseudoComponent that evaluates
    the combined expression (see ``_combined_expr``) to the scope, and
    ``evaluate`` reads that pseudocomp's ``out0``.
    """
    def __init__(self, lhs, comparator, rhs, scope):
        """Parse both sides of the constraint in ``scope``.

        Raises the error built by ``ExprEvaluator._invalid_expression_error``
        if either side references variables that cannot be resolved.
        """
        self.lhs = ExprEvaluator(lhs, scope=scope)
        unresolved_vars = self.lhs.get_unresolved()

        if unresolved_vars:
            msg = "Left hand side of constraint '{0}' has invalid variables {1}"
            expression = ' '.join((lhs, comparator, rhs))

            raise ExprEvaluator._invalid_expression_error(unresolved_vars,
                                                          expr=expression,
                                                          msg=msg)
        self.rhs = ExprEvaluator(rhs, scope=scope)
        unresolved_vars = self.rhs.get_unresolved()

        if unresolved_vars:
            msg = "Right hand side of constraint '{0}' has invalid variables {1}"
            expression = ' '.join((lhs, comparator, rhs))

            raise ExprEvaluator._invalid_expression_error(unresolved_vars,
                                                          expr=expression,
                                                          msg=msg)
        self.comparator = comparator
        self.pcomp_name = None
        self._size = None

        # Linear flag: constraints are nonlinear by default
        self.linear = False

    @property
    def size(self):
        """Total scalar items in this constraint (computed once, cached)."""
        if self._size is None:
            self._size = len(self.evaluate(self.lhs.scope))
        return self._size

    def activate(self):
        """Make this constraint active by creating the appropriate
        connections in the dependency graph.

        The pseudocomp is created and added to the scope only on the first
        call; every call (re)makes its connections.
        """
        if self.pcomp_name is None:
            pseudo = PseudoComponent(self.lhs.scope,
                                     self._combined_expr(),
                                     pseudo_type='constraint')
            self.pcomp_name = pseudo.name
            self.lhs.scope.add(pseudo.name, pseudo)
        # Look the pseudocomp up by its stored name: on repeat calls the
        # branch above is skipped, so a local 'pseudo' would be unbound.
        getattr(self.lhs.scope, self.pcomp_name).make_connections(self.lhs.scope)

    def deactivate(self):
        """Remove this constraint from the dependency graph and remove
        its pseudocomp from the scoping object.
        """
        if self.pcomp_name:
            scope = self.lhs.scope
            try:
                pcomp = getattr(scope, self.pcomp_name)
            except AttributeError:
                # Already gone from the scope; nothing to remove.
                pass
            else:
                scope.remove(pcomp.name)
            finally:
                self.pcomp_name = None

    @staticmethod
    def _is_zero(text):
        """Return True if ``text`` parses as a float equal to zero."""
        try:
            return float(text) == 0
        except (TypeError, ValueError):
            return False

    def _combined_expr(self):
        """Given a constraint object, take the lhs, operator, and
        rhs and combine them into a single expression by moving rhs
        terms over to the lhs.  For example,
        for the constraint 'C1.x < C2.y + 7', return the expression
        'C1.x - C2.y - 7'.  Depending on the direction of the operator,
        the sign of the expression may be flipped.  The final form of
        the constraint, when evaluated, will be considered to be satisfied
        if it evaluates to a value <= 0.
        """
        scope = self.lhs.scope

        # For '>' comparators swap the sides so the result is always
        # expressed as "first - second <= 0".
        if self.comparator.startswith('>'):
            first, second = self.rhs.text, self.lhs.text
        else:
            first, second = self.lhs.text, self.rhs.text

        # Drop a literal-zero side instead of emitting '0-(...)' or '...-(0)'.
        if self._is_zero(first):
            newexpr = "-(%s)" % second
        elif self._is_zero(second):
            newexpr = "%s" % first
        else:
            newexpr = '%s-(%s)' % (first, second)

        return ExprEvaluator(newexpr, scope)

    def copy(self):
        """ Returns a copy of our self. """
        return Constraint(str(self.lhs),
                          self.comparator,
                          str(self.rhs),
                          scope=self.lhs.scope)

    def evaluate(self, scope):
        """Returns the value of the constraint as a sequence.

        Reads ``out0`` of the pseudocomp named ``pcomp_name`` in ``scope``;
        an ndarray is flattened, any other value is wrapped in a list.
        """
        pcomp = getattr(scope, self.pcomp_name)
        val = pcomp.out0

        if isinstance(val, ndarray):
            return val.flatten()
        else:
            return [val]

    def evaluate_gradient(self, scope, stepsize=1.0e-6, wrt=None):
        """Returns the gradient of the constraint eq/ineq as a tuple of the
        form (lhs, rhs, comparator, is_violated)."""

        lhs = self.lhs.evaluate_gradient(scope=scope,
                                         stepsize=stepsize,
                                         wrt=wrt)
        # NOTE(review): __init__ always stores an ExprEvaluator in self.rhs,
        # so this float branch looks unreachable.
        if isinstance(self.rhs, float):
            rhs = 0.
        else:
            rhs = self.rhs.evaluate_gradient(scope=scope,
                                             stepsize=stepsize,
                                             wrt=wrt)

        return (lhs, rhs, self.comparator, not _ops[self.comparator](lhs, rhs))

    def get_referenced_compnames(self):
        """Returns a set of names of each component referenced by this
        constraint.
        """
        if isinstance(self.rhs, float):
            return self.lhs.get_referenced_compnames()
        else:
            return self.lhs.get_referenced_compnames().union(
                self.rhs.get_referenced_compnames())

    def get_referenced_varpaths(self, copy=True):
        """Returns a set of the variable paths referenced by either side of
        this constraint.  ``copy`` is passed through to each side.
        """
        if isinstance(self.rhs, float):
            return self.lhs.get_referenced_varpaths(copy=copy)
        else:
            # union() builds a new set, so the sides' sets are not mutated
            # even when copy=False.
            return self.lhs.get_referenced_varpaths(copy=copy).union(
                self.rhs.get_referenced_varpaths(copy=copy))

    def __str__(self):
        return ' '.join((str(self.lhs), self.comparator, str(self.rhs)))

    def __eq__(self, other):
        if not isinstance(other, Constraint):
            return False
        return (self.lhs, self.comparator, self.rhs) == \
               (other.lhs, other.comparator, other.rhs)
class Constraint(object):
    """ Object that stores info for a single constraint.

    The constraint ``lhs comparator rhs`` is stored as two ExprEvaluators
    and a comparator string.  A pseudo component evaluating the combined
    expression (see ``_combined_expr``) is created during construction;
    its name is kept in ``pcomp_name``.
    """

    def __init__(self, lhs, comparator, rhs, scope, jacs=None):
        """Parse both sides in ``scope`` and create the pseudo component.

        lhs, rhs: str
            Expression text for each side of the constraint.
        comparator: str
            Comparison operator text, e.g. '=', '<' or '>'.
        scope: object
            Scoping object the expressions are resolved against.
        jacs: optional
            User-defined jacobian function; stored unchanged on ``self.jacs``.

        Raises the error built by ``ExprEvaluator._invalid_expression_error``
        if either side has unresolved variables.
        """
        self.lhs = ExprEvaluator(lhs, scope=scope)
        self._pseudo = None
        self.pcomp_name = None
        unresolved_vars = self.lhs.get_unresolved()

        if unresolved_vars:
            msg = "Left hand side of constraint '{0}' has invalid variables {1}"
            expression = ' '.join((lhs, comparator, rhs))

            raise ExprEvaluator._invalid_expression_error(unresolved_vars,
                                                          expr=expression,
                                                          msg=msg)
        self.rhs = ExprEvaluator(rhs, scope=scope)
        unresolved_vars = self.rhs.get_unresolved()

        if unresolved_vars:
            msg = "Right hand side of constraint '{0}' has invalid variables {1}"
            expression = ' '.join((lhs, comparator, rhs))

            raise ExprEvaluator._invalid_expression_error(unresolved_vars,
                                                          expr=expression,
                                                          msg=msg)
        self.comparator = comparator
        self._size = None

        # Linear flag: constraints are nonlinear by default
        self.linear = False

        # User-defined jacobian function
        self.jacs = jacs

        self._create_pseudo()

    @property
    def size(self):
        """Total scalar items in this constraint (computed once, cached)."""
        if self._size is None:
            self._size = len(self.evaluate(self.lhs.scope))
        return self._size

    def _create_pseudo(self):
        """Create our pseudo component.

        Picks SimpleEQConPComp or SimpleEQ0PComp for a few simple equality
        patterns (see comments below) and PseudoComponent otherwise, then
        sets ``self._pseudo`` and ``self.pcomp_name``.
        """
        if self.comparator == '=':
            subtype = 'equality'
        else:
            subtype = 'inequality'

        # check for simple structure of equality constraint,
        # either
        #     var1 = var2
        #  OR
        #     var1 - var2 = 0
        #  OR
        #     var1 = 0
        lrefs = list(self.lhs.ordered_refs())
        rrefs = list(self.rhs.ordered_refs())

        # Numeric value of each side's text, or None if it is not a number.
        try:
            leftval = float(self.lhs.text)
        except ValueError:
            leftval = None

        try:
            rightval = float(self.rhs.text)
        except ValueError:
            rightval = None

        pseudo_class = PseudoComponent

        if self.comparator == '=':
            # look for var1-var2=0
            if len(lrefs) == 2 and len(rrefs) == 0:
                if rightval == 0. and \
                        _remove_spaces(self.lhs.text) == \
                            lrefs[0]+'-'+lrefs[1]:
                    pseudo_class = SimpleEQConPComp
            # look for 0=var1-var2
            elif len(lrefs) == 0 and len(rrefs) == 2:
                if leftval==0. and \
                       _remove_spaces(self.rhs.text) == \
                            rrefs[0]+'-'+rrefs[1]:
                    pseudo_class = SimpleEQConPComp
            # look for var1=var2
            elif len(lrefs) == 1 and len(rrefs) == 1:
                if lrefs[0] == self.lhs.text and \
                           rrefs[0] == self.rhs.text:
                    pseudo_class = SimpleEQConPComp
            # look for var1=0
            # NOTE(review): this only checks that the rhs is numeric
            # (rightval is not None), not that it equals 0, and does not
            # check that the lhs text is exactly the single reference;
            # confirm SimpleEQ0PComp is meant for e.g. 'x = 5' or 'x**2 = 0'.
            elif len(lrefs) == 1 and len(rrefs) == 0 and rightval is not None:
                pseudo_class = SimpleEQ0PComp

        self._pseudo = pseudo_class(self.lhs.scope,
                                    self._combined_expr(),
                                    pseudo_type='constraint',
                                    subtype=subtype,
                                    exprobject=self)

        self.pcomp_name = self._pseudo.name

    def activate(self, driver):
        """Make this constraint active by creating the appropriate
        connections in the dependency graph.

        Delegates to the pseudo component's ``activate`` with the scope
        and ``driver``.
        """
        self._pseudo.activate(self.lhs.scope, driver)

    def deactivate(self):
        """Remove this constraint from the dependency graph and remove
        its pseudocomp from the scoping object.
        """
        if self._pseudo is not None:
            scope = self.lhs.scope
            # NOTE(review): pcomp itself is unused; the getattr only checks
            # that the pseudocomp is still present in the scope.
            try:
                pcomp = getattr(scope, self._pseudo.name)
            except AttributeError:
                pass
            else:
                scope.remove(self._pseudo.name)

    def _combined_expr(self):
        """Given a constraint object, take the lhs, operator, and
        rhs and combine them into a single expression by moving rhs
        terms over to the lhs.  For example,
        for the constraint 'C1.x < C2.y + 7', return the expression
        'C1.x - C2.y - 7'.  Depending on the direction of the operator,
        the sign of the expression may be flipped.  The final form of
        the constraint, when evaluated, will be considered to be satisfied
        if it evaluates to a value <= 0.
        """
        scope = self.lhs.scope

        # Swap sides for '>' comparators so the result is "first - second".
        if self.comparator.startswith('>'):
            first = self.rhs.text
            second = self.lhs.text
        else:
            first = self.lhs.text
            second = self.rhs.text

        # Detect literal-zero sides so they can be dropped from the result.
        first_zero = False
        try:
            f = float(first)
        except Exception:
            pass
        else:
            if f == 0:
                first_zero = True

        second_zero = False
        try:
            f = float(second)
        except Exception:
            pass
        else:
            if f == 0:
                second_zero = True

        if first_zero:
            newexpr = "-(%s)" % second
        elif second_zero:
            newexpr = first
        else:
            newexpr = '%s-(%s)' % (first, second)

        return ExprEvaluator(newexpr, scope)

    def copy(self):
        """ Returns a copy of our self. """
        return Constraint(str(self.lhs), self.comparator, str(self.rhs),
                          scope=self.lhs.scope, jacs=self.jacs)

    def evaluate(self, scope):
        """Returns the value of the constraint as a sequence.

        First tries to read the pseudocomp's ``out0`` entry from its
        system's 'u' vector (or the negated 'f' vector entry when the
        output is hidden); if any lookup raises KeyError or AttributeError,
        falls back to reading ``out0`` directly from the pseudocomp.
        """
        vname = self.pcomp_name + '.out0'
        try:
            system = getattr(scope,self.pcomp_name)._system
            info = system.vec['u']._info[scope.name2collapsed[vname]]
            # if a pseudocomp output is marked as hidden, that means that
            # it's really a residual, but it's mapped in the vector to
            # the corresponding state, so don't pull that value because
            # we want the actual residual value
            if info.hide: # it's a residual so pull from f vector
                return -system.vec['f'][scope.name2collapsed[vname]]
            else:
                return info.view
        except (KeyError, AttributeError):
            # No system/vector entry available: use the direct output below.
            pass

        val = getattr(scope, self.pcomp_name).out0

        if isinstance(val, ndarray):
            return val.flatten()
        else:
            return [val]

    def get_referenced_compnames(self):
        """Returns a set of names of each component referenced by this
        constraint.
        """
        # NOTE(review): __init__ always stores an ExprEvaluator in self.rhs,
        # so the float branch looks unreachable.
        if isinstance(self.rhs, float):
            return self.lhs.get_referenced_compnames()
        else:
            return self.lhs.get_referenced_compnames().union(
                                            self.rhs.get_referenced_compnames())

    def get_referenced_varpaths(self, copy=True, refs=False):
        """Returns a set of the variable paths referenced by either side of
        this constraint.  ``copy`` and ``refs`` are passed through to each
        side's get_referenced_varpaths.
        """
        if isinstance(self.rhs, float):
            return self.lhs.get_referenced_varpaths(copy=copy, refs=refs)
        else:
            return self.lhs.get_referenced_varpaths(copy=copy, refs=refs).union(
                    self.rhs.get_referenced_varpaths(copy=copy, refs=refs))

    def check_resolve(self):
        """Returns True if this constraint has no unresolved references."""
        return self.lhs.check_resolve() and self.rhs.check_resolve()

    def get_unresolved(self):
        """Returns a list of the unresolved names from both sides, without
        duplicates (order is not guaranteed)."""
        return list(set(self.lhs.get_unresolved()).union(self.rhs.get_unresolved()))

    def name_changed(self, old, new):
        """Update expressions if necessary when an object is renamed."""
        self.rhs.name_changed(old, new)
        self.lhs.name_changed(old, new)

    def __str__(self):
        return ' '.join((str(self.lhs), self.comparator, str(self.rhs)))

    def __eq__(self, other):
        if not isinstance(other, Constraint):
            return False
        return (self.lhs, self.comparator, self.rhs) == \
               (other.lhs, other.comparator, other.rhs)
class Constraint2Sided(Constraint):
    """ Object that stores info for a double-sided constraint.

    ``lhs`` and ``rhs`` hold the bounds and ``center`` the constrained
    expression.  The bounds are evaluated once during construction and
    stored as ``low`` and ``high``; the pseudo component evaluates
    ``center`` only.
    """

    def __init__(self, lhs, center, rhs, comparator, scope, jacs=None):
        """Parse the bounds and center in ``scope`` and create the pseudocomp.

        Raises the error built by ``ExprEvaluator._invalid_expression_error``
        if any of the three parts has unresolved variables.
        """
        self.lhs = ExprEvaluator(lhs, scope=scope)
        unresolved_vars = self.lhs.get_unresolved()

        self._pseudo = None
        self.pcomp_name = None

        if unresolved_vars:
            msg = "Left hand side of constraint '{0}' has invalid variables {1}"
            expression = ' '.join((lhs, comparator, center, comparator,
                                   rhs))

            raise ExprEvaluator._invalid_expression_error(unresolved_vars,
                                                          expr=expression,
                                                          msg=msg)
        self.center = ExprEvaluator(center, scope=scope)
        unresolved_vars = self.center.get_unresolved()

        if unresolved_vars:
            msg = "Center of constraint '{0}' has invalid variables {1}"
            expression = ' '.join((lhs, comparator, center, comparator,
                                   rhs))

            raise ExprEvaluator._invalid_expression_error(unresolved_vars,
                                                          expr=expression,
                                                          msg=msg)
        self.rhs = ExprEvaluator(rhs, scope=scope)
        unresolved_vars = self.rhs.get_unresolved()

        if unresolved_vars:
            msg = "Right hand side of constraint '{0}' has invalid variables {1}"
            expression = ' '.join((lhs, comparator, center, comparator,
                                   rhs))

            raise ExprEvaluator._invalid_expression_error(unresolved_vars,
                                                          expr=expression,
                                                          msg=msg)
        self.comparator = comparator
        self._size = None

        # Linear flag: constraints are nonlinear by default
        self.linear = False

        # Bounds are evaluated once, here.
        self.low = self.lhs.evaluate()
        self.high = self.rhs.evaluate()

        # User-defined jacobian function
        self.jacs = jacs

        self._create_pseudo()

    def _create_pseudo(self):
        """Create our pseudo component for the center expression."""
        scope = self.lhs.scope
        refs = list(self.center.ordered_refs())
        pseudo_class = PseudoComponent

        # look for a<var1<b
        if len(refs) == 1 and self.center.text == refs[0]:
            pseudo_class = SimpleEQ0PComp

        self._pseudo = pseudo_class(scope,
                                    self.center,
                                    pseudo_type='constraint',
                                    subtype='inequality',
                                    exprobject=self)

        self.pcomp_name = self._pseudo.name

    def _combined_expr(self):
        """Only need the center expression
        """
        return self.center

    def copy(self):
        """ Returns a copy of our self. """
        return Constraint2Sided(str(self.lhs), str(self.center), str(self.rhs),
                                self.comparator, scope=self.lhs.scope,
                                jacs=self.jacs)

    def get_referenced_compnames(self):
        """Returns a set of names of each component referenced by the
        center expression of this constraint.
        """
        return self.center.get_referenced_compnames()

    def get_referenced_varpaths(self, copy=True, refs=False):
        """Returns a set of the variable paths referenced by the center
        expression of this constraint.
        """
        return self.center.get_referenced_varpaths(copy=copy, refs=refs)

    def check_resolve(self):
        """Returns True if the bounds and the center all resolve."""
        return (self.lhs.check_resolve() and self.center.check_resolve()
                and self.rhs.check_resolve())

    def get_unresolved(self):
        """Returns a list of unresolved names from the bounds and the
        center, without duplicates (order is not guaranteed)."""
        return list(set(self.lhs.get_unresolved()).union(
            self.center.get_unresolved(), self.rhs.get_unresolved()))

    def name_changed(self, old, new):
        """Update expressions if necessary when an object is renamed."""
        self.rhs.name_changed(old, new)
        self.lhs.name_changed(old, new)
        self.center.name_changed(old, new)

    def __str__(self):
        # NOTE(review): this orders the parts as 'lhs center rhs comparator',
        # unlike the 'lhs comparator center comparator rhs' form used in the
        # __init__ error messages; left unchanged in case callers rely on it.
        return ' '.join((str(self.lhs), str(self.center), str(self.rhs), self.comparator))

    def __eq__(self, other):
        if not isinstance(other, Constraint2Sided):
            return False
        # Compare against other.center (previously self.center was
        # compared with itself, so differing centers looked equal).
        return (self.lhs, self.center, self.comparator, self.rhs) == \
               (other.lhs, other.center, other.comparator, other.rhs)
class Constraint(object):
    """ Object that stores info for a single constraint.

    The constraint ``lhs comparator rhs`` is kept as two ExprEvaluators and
    a comparator string.  ``activate`` adds a PseudoComponent that evaluates
    the combined expression (see ``_combined_expr``) to the scope, and
    ``evaluate`` reads that pseudocomp's ``out0``.
    """

    def __init__(self, lhs, comparator, rhs, scope):
        """Parse both sides of the constraint in ``scope``.

        Raises ValueError if either side fails ``check_resolve()``.
        """
        self.lhs = ExprEvaluator(lhs, scope=scope)
        if not self.lhs.check_resolve():
            raise ValueError("Constraint '%s' has an invalid left-hand-side."
                              % ' '.join([lhs, comparator, rhs]))

        self.comparator = comparator

        self.rhs = ExprEvaluator(rhs, scope=scope)
        if not self.rhs.check_resolve():
            raise ValueError("Constraint '%s' has an invalid right-hand-side."
                              % ' '.join([lhs, comparator, rhs]))

        self.pcomp_name = None
        self._size = None

    @property
    def size(self):
        """Total scalar items in this constraint (computed once, cached)."""
        if self._size is None:
            self._size = len(self.evaluate(self.lhs.scope))
        return self._size

    def activate(self):
        """Make this constraint active by creating the appropriate
        connections in the dependency graph.

        The pseudocomp is created and added to the scope only on the first
        call; every call (re)makes its connections.
        """
        if self.pcomp_name is None:
            pseudo = PseudoComponent(self.lhs.scope, self._combined_expr(),
                                     pseudo_type='constraint')
            self.pcomp_name = pseudo.name
            self.lhs.scope.add(pseudo.name, pseudo)
        # Look the pseudocomp up by its stored name: on repeat calls the
        # branch above is skipped, so a local 'pseudo' would be unbound.
        getattr(self.lhs.scope, self.pcomp_name).make_connections(self.lhs.scope)

    def deactivate(self):
        """Remove this constraint from the dependency graph and remove
        its pseudocomp from the scoping object.
        """
        if self.pcomp_name:
            scope = self.lhs.scope
            try:
                pcomp = getattr(scope, self.pcomp_name)
            except AttributeError:
                # Already gone from the scope; nothing to remove.
                pass
            else:
                scope.remove(pcomp.name)
            finally:
                self.pcomp_name = None

    @staticmethod
    def _is_zero(text):
        """Return True if ``text`` parses as a float equal to zero."""
        try:
            return float(text) == 0
        except (TypeError, ValueError):
            return False

    def _combined_expr(self):
        """Given a constraint object, take the lhs, operator, and
        rhs and combine them into a single expression by moving rhs
        terms over to the lhs.  For example,
        for the constraint 'C1.x < C2.y + 7', return the expression
        'C1.x - C2.y - 7'.  Depending on the direction of the operator,
        the sign of the expression may be flipped.  The final form of
        the constraint, when evaluated, will be considered to be satisfied
        if it evaluates to a value <= 0.
        """
        scope = self.lhs.scope

        # For '>' comparators swap the sides so the result is always
        # expressed as "first - second <= 0".
        if self.comparator.startswith('>'):
            first, second = self.rhs.text, self.lhs.text
        else:
            first, second = self.lhs.text, self.rhs.text

        # Drop a literal-zero side instead of emitting '0-(...)' or '...-(0)'.
        if self._is_zero(first):
            newexpr = "-(%s)" % second
        elif self._is_zero(second):
            newexpr = "%s" % first
        else:
            newexpr = '%s-(%s)' % (first, second)

        return ExprEvaluator(newexpr, scope)

    def copy(self):
        """ Returns a copy of our self. """
        return Constraint(str(self.lhs), self.comparator, str(self.rhs),
                          scope=self.lhs.scope)

    def evaluate(self, scope):
        """Returns the value of the constraint as a sequence.

        If the pseudocomp reports it is not valid, its ``out0`` is updated
        first.  An ndarray value is flattened; any other value is wrapped
        in a list.
        """
        pcomp = getattr(scope, self.pcomp_name)
        if not pcomp.is_valid():
            pcomp.update_outputs(['out0'])
        val = pcomp.out0

        if isinstance(val, ndarray):
            return val.flatten()
        else:
            return [val]

    def evaluate_gradient(self, scope, stepsize=1.0e-6, wrt=None):
        """Returns the gradient of the constraint eq/ineq as a tuple of the
        form (lhs, rhs, comparator, is_violated)."""

        lhs = self.lhs.evaluate_gradient(scope=scope, stepsize=stepsize, wrt=wrt)
        # NOTE(review): __init__ always stores an ExprEvaluator in self.rhs,
        # so this float branch looks unreachable.
        if isinstance(self.rhs, float):
            rhs = 0.
        else:
            rhs = self.rhs.evaluate_gradient(scope=scope, stepsize=stepsize, wrt=wrt)

        return (lhs, rhs, self.comparator, not _ops[self.comparator](lhs, rhs))

    def get_referenced_compnames(self):
        """Returns a set of names of each component referenced by this
        constraint.
        """
        if isinstance(self.rhs, float):
            return self.lhs.get_referenced_compnames()
        else:
            return self.lhs.get_referenced_compnames().union(
                                            self.rhs.get_referenced_compnames())

    def get_referenced_varpaths(self, copy=True):
        """Returns a set of the variable paths referenced by either side of
        this constraint.  ``copy`` is passed through to each side.
        """
        if isinstance(self.rhs, float):
            return self.lhs.get_referenced_varpaths(copy=copy)
        else:
            # union() builds a new set, so the sides' sets are not mutated
            # even when copy=False.
            return self.lhs.get_referenced_varpaths(copy=copy).union(
                                    self.rhs.get_referenced_varpaths(copy=copy))

    def __str__(self):
        return ' '.join([str(self.lhs), self.comparator, str(self.rhs)])

    def __eq__(self, other):
        if not isinstance(other, Constraint):
            return False
        return (self.lhs, self.comparator, self.rhs) == \
               (other.lhs, other.comparator, other.rhs)
class Constraint(object):
    """ Object that stores info for a single constraint.

    The constraint has the form ``lhs <comparator> rhs``.  Both sides are
    held as ExprEvaluator objects bound to the same scope.  Once activated,
    the constraint is evaluated through a pseudo-component added to that
    scope, whose name is kept in ``pcomp_name``.
    """
    def __init__(self, lhs, comparator, rhs, scope):
        """Parse both sides of the constraint in *scope*.

        Args:
            lhs: expression text for the left-hand side.
            comparator: comparison operator text (e.g. '=', '<', '>=').
            rhs: expression text for the right-hand side.
            scope: object in which the expressions' variables are resolved.

        Raises:
            The error built by ``ExprEvaluator._invalid_expression_error``
            if either side references variables that cannot be resolved.
        """
        self.lhs = ExprEvaluator(lhs, scope=scope)
        unresolved_vars = self.lhs.get_unresolved()

        if unresolved_vars:
            msg = "Left hand side of constraint '{0}' has invalid variables {1}"
            expression = ' '.join((lhs, comparator, rhs))

            raise ExprEvaluator._invalid_expression_error(unresolved_vars,
                                                          expr=expression,
                                                          msg=msg)
        self.rhs = ExprEvaluator(rhs, scope=scope)
        unresolved_vars = self.rhs.get_unresolved()

        if unresolved_vars:
            msg = "Right hand side of constraint '{0}' has invalid variables {1}"
            expression = ' '.join((lhs, comparator, rhs))

            raise ExprEvaluator._invalid_expression_error(unresolved_vars,
                                                          expr=expression,
                                                          msg=msg)
        self.comparator = comparator
        self.pcomp_name = None  # set by activate(), cleared by deactivate()
        self._size = None       # lazily computed by the ``size`` property

        # Linear flag: constraints are nonlinear by default
        self.linear = False

    @staticmethod
    def _text_to_float(text):
        """Return *text* converted to a float, or None when it is not a
        plain numeric literal (e.g. an expression referencing variables).
        """
        try:
            return float(text)
        except (TypeError, ValueError):
            return None

    @property
    def size(self):
        """Total scalar items in this constraint.

        Computed on first access by evaluating the constraint in the lhs
        scope, then cached.  NOTE(review): evaluation looks up the
        pseudo-component by ``pcomp_name``, so this presumably requires the
        constraint to have been activated first.
        """
        if self._size is None:
            self._size = len(self.evaluate(self.lhs.scope))
        return self._size

    def activate(self, driver):
        """Make this constraint active by creating the appropriate
        connections in the dependency graph.

        Does nothing if a pseudo-component already exists (``pcomp_name``
        is set).  Otherwise a pseudo-component class is chosen: a simpler
        specialized class for a few trivially structured equality
        constraints, ``PseudoComponent`` otherwise.  An instance is added to
        the lhs scope and its ``make_connections`` is called with *driver*.
        """
        if self.pcomp_name is None:
            if self.comparator == '=':
                subtype = 'equality'
            else:
                subtype = 'inequality'

            # check for simple structure of equality constraint,
            # either
            #     var1 = var2
            #  OR
            #     var1 - var2 = 0   (or 0 = var1 - var2)
            #  OR
            #     var1 = <numeric constant>
            lrefs = list(self.lhs.ordered_refs())
            rrefs = list(self.rhs.ordered_refs())

            # None when a side is not a plain numeric literal.
            leftval = self._text_to_float(self.lhs.text)
            rightval = self._text_to_float(self.rhs.text)

            pseudo_class = PseudoComponent

            if self.comparator == '=':
                # look for var1-var2=0
                if len(lrefs) == 2 and len(rrefs) == 0:
                    if (rightval == 0. and
                            _remove_spaces(self.lhs.text) ==
                            lrefs[0] + '-' + lrefs[1]):
                        pseudo_class = SimpleEQConPComp
                # look for 0=var1-var2
                elif len(lrefs) == 0 and len(rrefs) == 2:
                    if (leftval == 0. and
                            _remove_spaces(self.rhs.text) ==
                            rrefs[0] + '-' + rrefs[1]):
                        pseudo_class = SimpleEQConPComp
                # look for var1=var2
                elif len(lrefs) == 1 and len(rrefs) == 1:
                    if (lrefs[0] == self.lhs.text and
                            rrefs[0] == self.rhs.text):
                        pseudo_class = SimpleEQConPComp
                # look for var1=<numeric constant>; the rhs may be any
                # number (not only 0)
                elif (len(lrefs) == 1 and len(rrefs) == 0 and
                      rightval is not None):
                    pseudo_class = SimpleEQ0PComp

            pseudo = pseudo_class(self.lhs.scope,
                                  self._combined_expr(),
                                  pseudo_type='constraint',
                                  subtype=subtype,
                                  exprobject=self)

            self.pcomp_name = pseudo.name
            self.lhs.scope.add(pseudo.name, pseudo)
            getattr(self.lhs.scope,
                    pseudo.name).make_connections(self.lhs.scope, driver)

    def _combined_expr(self):
        """Given a constraint object, take the lhs, operator, and
        rhs and combine them into a single expression by moving rhs
        terms over to the lhs.  For example,
        for the constraint 'C1.x < C2.y + 7', return the expression
        'C1.x - C2.y - 7'.  Depending on the direction of the operator,
        the sign of the expression may be flipped.  The final form of
        the constraint, when evaluated, will be considered to be satisfied
        if it evaluates to a value <= 0.

        A side that is a literal zero is dropped rather than subtracted.
        """
        scope = self.lhs.scope

        # Orient the expression so that "satisfied" always means <= 0.
        if self.comparator.startswith('>'):
            first = self.rhs.text
            second = self.lhs.text
        else:
            first = self.lhs.text
            second = self.rhs.text

        # _text_to_float returns None for non-literals, and None != 0.
        if self._text_to_float(first) == 0:
            newexpr = "-(%s)" % second
        elif self._text_to_float(second) == 0:
            newexpr = first
        else:
            newexpr = '%s-(%s)' % (first, second)

        return ExprEvaluator(newexpr, scope)

    def deactivate(self):
        """Remove this constraint from the dependency graph and remove
        its pseudocomp from the scoping object.

        ``pcomp_name`` is always reset to None, even if the
        pseudo-component is no longer present in the scope.
        """
        if self.pcomp_name:
            scope = self.lhs.scope
            try:
                pcomp = getattr(scope, self.pcomp_name)
            except AttributeError:
                pass  # already gone from the scope; nothing to remove
            else:
                scope.remove(pcomp.name)
            finally:
                self.pcomp_name = None

    def copy(self):
        """ Returns a copy of our self, re-parsed from the expression text
        in the same scope.
        """
        return Constraint(str(self.lhs),
                          self.comparator,
                          str(self.rhs),
                          scope=self.lhs.scope)

    def evaluate(self, scope):
        """Returns the value of the constraint as a sequence.

        The value is the ``out0`` output of the pseudo-component named
        ``pcomp_name`` in *scope*: an ndarray is flattened, any other value
        is wrapped in a one-element list.
        """

        pcomp = getattr(scope, self.pcomp_name)
        val = pcomp.out0

        if isinstance(val, ndarray):
            return val.flatten()
        else:
            return [val]

    def get_referenced_compnames(self):
        """Returns a set of names of each component referenced by this
        constraint.
        """
        if isinstance(self.rhs, float):
            return self.lhs.get_referenced_compnames()
        else:
            return self.lhs.get_referenced_compnames().union(
                self.rhs.get_referenced_compnames())

    def get_referenced_varpaths(self, copy=True, refs=False):
        """Returns a set of the variable paths referenced by this
        constraint.

        *copy* and *refs* are forwarded unchanged to each side's
        ``get_referenced_varpaths``.
        """
        if isinstance(self.rhs, float):
            return self.lhs.get_referenced_varpaths(copy=copy, refs=refs)
        else:
            return self.lhs.get_referenced_varpaths(
                copy=copy, refs=refs).union(
                    self.rhs.get_referenced_varpaths(copy=copy, refs=refs))

    def __str__(self):
        return ' '.join((str(self.lhs), self.comparator, str(self.rhs)))

    def __eq__(self, other):
        if not isinstance(other, Constraint):
            return False
        return (self.lhs, self.comparator, self.rhs) == \
               (other.lhs, other.comparator, other.rhs)
class Constraint2Sided(Constraint):
    """ Object that stores info for a double-sided constraint.

    The constraint has the form ``lhs <comparator> center <comparator> rhs``.
    ``lhs`` and ``rhs`` are evaluated once at construction and stored as
    ``low`` and ``high``.  Only ``center`` is handed to the pseudo-component
    created by :meth:`activate`.
    """
    def __init__(self, lhs, center, rhs, comparator, scope):
        """Parse the three parts of the constraint in *scope*.

        Args:
            lhs: expression text for the lower side.
            center: expression text for the constrained quantity.
            rhs: expression text for the upper side.
            comparator: comparison operator text used on both sides.
            scope: object in which the expressions' variables are resolved.

        Raises:
            The error built by ``ExprEvaluator._invalid_expression_error``
            if any part references variables that cannot be resolved.
        """
        self.lhs = ExprEvaluator(lhs, scope=scope)
        unresolved_vars = self.lhs.get_unresolved()

        if unresolved_vars:
            msg = "Left hand side of constraint '{0}' has invalid variables {1}"
            expression = ' '.join((lhs, comparator, center, comparator, rhs))

            raise ExprEvaluator._invalid_expression_error(unresolved_vars,
                                                          expr=expression,
                                                          msg=msg)
        self.center = ExprEvaluator(center, scope=scope)
        unresolved_vars = self.center.get_unresolved()

        if unresolved_vars:
            msg = "Center of constraint '{0}' has invalid variables {1}"
            expression = ' '.join((lhs, comparator, center, comparator, rhs))

            raise ExprEvaluator._invalid_expression_error(unresolved_vars,
                                                          expr=expression,
                                                          msg=msg)
        self.rhs = ExprEvaluator(rhs, scope=scope)
        unresolved_vars = self.rhs.get_unresolved()

        if unresolved_vars:
            msg = "Right hand side of constraint '{0}' has invalid variables {1}"
            expression = ' '.join((lhs, comparator, center, comparator, rhs))

            raise ExprEvaluator._invalid_expression_error(unresolved_vars,
                                                          expr=expression,
                                                          msg=msg)
        self.comparator = comparator
        self.pcomp_name = None  # set by activate(), cleared by deactivate()
        self._size = None       # lazily computed by the ``size`` property

        # Linear flag: constraints are nonlinear by default
        self.linear = False

        # The bounds are evaluated once here, not on every evaluation.
        self.low = self.lhs.evaluate()
        self.high = self.rhs.evaluate()

    def activate(self, driver):
        """Make this constraint active by creating the appropriate
        connections in the dependency graph.

        Does nothing if a pseudo-component already exists.  Otherwise a
        pseudo-component wrapping only the center expression is added to
        the lhs scope, and its ``make_connections`` is called with *driver*.
        """
        if self.pcomp_name is None:

            scope = self.lhs.scope
            refs = list(self.center.ordered_refs())
            pseudo_class = PseudoComponent

            # look for a<var1<b: the center is just a single variable
            if len(refs) == 1 and self.center.text == refs[0]:
                pseudo_class = SimpleEQ0PComp

            pseudo = pseudo_class(scope,
                                  self.center,
                                  pseudo_type='constraint',
                                  subtype='inequality',
                                  exprobject=self)

            self.pcomp_name = pseudo.name
            scope.add(pseudo.name, pseudo)
            getattr(scope, pseudo.name).make_connections(scope, driver)

    def _combined_expr(self):
        """Only need the center expression; the bounds are held separately
        in ``low``/``high``.
        """
        return self.center

    def copy(self):
        """ Returns a copy of our self, re-parsed from the expression text
        in the same scope.
        """
        return Constraint2Sided(str(self.lhs),
                                str(self.center),
                                str(self.rhs),
                                self.comparator,
                                scope=self.lhs.scope)

    def get_referenced_compnames(self):
        """Returns a set of names of each component referenced by the
        center of this constraint.
        """
        return self.center.get_referenced_compnames()

    def get_referenced_varpaths(self, copy=True, refs=False):
        """Returns a set of the variable paths referenced by the center of
        this constraint.

        *copy* and *refs* are forwarded unchanged to the center evaluator.
        """
        return self.center.get_referenced_varpaths(copy=copy, refs=refs)

    def __str__(self):
        # Same 'lhs cmp center cmp rhs' layout used in __init__'s messages.
        return ' '.join((str(self.lhs), self.comparator, str(self.center),
                         self.comparator, str(self.rhs)))

    def __eq__(self, other):
        if not isinstance(other, Constraint2Sided):
            return False
        # Compare against other.center (previously self.center was compared
        # with itself, so differing centers were never detected).
        return (self.lhs, self.center, self.comparator, self.rhs) == \
               (other.lhs, other.center, other.comparator, other.rhs)