class expression(ExpressionBase):
    """
    An expression stored as a flat, ordered list of operators.

    The last operator in the list (see ``last_node``) is the root of the
    expression. The underlying list may be shared between several
    ``expression`` objects; each object only looks at its own first
    ``_n_opers`` entries.
    """
    __slots__ = ('_operators', '_n_opers', '_vars', '_params', '_floats')

    def __init__(self, expr=None):
        """
        Parameters
        ----------
        expr: expression
            Optional expression to start from. When omitted, an empty
            expression is created.
        """
        if expr is None:
            self._operators = []
        elif expr._operators and expr._operators[-1] is expr.last_node():
            # Nobody has appended past expr's last node, so the list can be
            # shared instead of copied.
            self._operators = expr._operators
        else:
            # Either expr is empty (indexing [-1] would raise IndexError) or
            # the shared list has been extended beyond expr's last node;
            # take a private copy of exactly expr's operators.
            self._operators = expr.list_of_operators()
        self._n_opers = len(self._operators)
        # Leaf caches, filled lazily by _collect_leaves().
        self._vars = None
        self._params = None
        self._floats = None

    def append_operator(self, oper):
        """
        Append an operator to this expression; it becomes the new last node.

        Parameters
        ----------
        oper: Operator
        """
        self._operators.append(oper)
        self._n_opers += 1
        if self._vars is not None:
            # The cached leaf sets were computed for the previous operator
            # list; drop them so they are rebuilt on next access.
            self._vars = None
            self._params = None
            self._floats = None

    def last_node(self):
        """
        Returns
        -------
        last_node: Operator
        """
        return self._operators[self._n_opers - 1]

    def list_of_operators(self):
        """
        Returns
        -------
        operators: list
            A new list holding this expression's first ``_n_opers`` operators.
        """
        return self._operators[:self._n_opers]

    def is_leaf(self):
        return False

    def operators(self):
        """
        Returns
        -------
        operators: iterator
            Iterator over this expression's operators, without copying the
            underlying list.
        """
        return itertools.islice(self._operators, 0, self._n_opers)

    def _binary_operation_helper(self, other, cls):
        """
        Build a new expression whose last node is ``cls`` applied to the last
        node of this expression and the last node of ``other``.

        Parameters
        ----------
        other: object
            Right-hand operand. A value whose type is in
            ``native_numeric_types`` is wrapped in ``Float`` first.
        cls: callable
            Called as ``cls(self.last_node(), other.last_node())`` to create
            the new operator.

        Returns
        -------
        expr: expression
        """
        if type(other) in native_numeric_types:
            other = Float(other)
        new_operator = cls(self.last_node(), other.last_node())
        expr = expression(self)
        for oper in other.operators():
            expr.append_operator(oper)
        expr.append_operator(new_operator)
        return expr

    def _unary_operation_helper(self, cls):
        """
        Build a new expression whose last node is ``cls`` applied to the last
        node of this expression.

        Parameters
        ----------
        cls: callable
            Called as ``cls(self.last_node())`` to create the new operator.

        Returns
        -------
        expr: expression
        """
        new_operator = cls(self.last_node())
        expr = expression(self)
        expr.append_operator(new_operator)
        return expr

    def evaluate(self):
        """
        Evaluate every operator in order and return the value recorded for
        the last node.
        """
        val_dict = dict()
        for oper in self.operators():
            oper.evaluate(val_dict)
        return val_dict[self.last_node()]

    def get_vars(self):
        """Yield the cached variable leaves, collecting them on first use."""
        if self._vars is None:
            self._collect_leaves()
        for i in self._vars:
            yield i

    def get_params(self):
        """Yield the cached parameter leaves, collecting them on first use."""
        if self._params is None:
            self._collect_leaves()
        for i in self._params:
            yield i

    def get_floats(self):
        """Yield the cached float leaves, collecting them on first use."""
        if self._floats is None:
            self._collect_leaves()
        for i in self._floats:
            yield i

    def _collect_leaves(self):
        """
        Rebuild the ``_vars``, ``_params`` and ``_floats`` caches from the
        operands of every operator in this expression.

        Raises
        ------
        ValueError
            If a leaf operand is not a variable, parameter, float or
            expression type.
        """
        self._vars = OrderedSet()
        self._params = OrderedSet()
        self._floats = OrderedSet()
        for oper in self.operators():
            for operand in oper.operands():
                # Operands that are not leaves are ignored here.
                if operand.is_leaf():
                    if operand.is_variable_type():
                        self._vars.add(operand)
                    elif operand.is_parameter_type():
                        self._params.add(operand)
                    elif operand.is_float_type():
                        self._floats.add(operand)
                    elif operand.is_expression_type():
                        self._vars.update(operand.get_vars())
                        self._params.update(operand.get_params())
                        self._floats.update(operand.get_floats())
                    else:
                        raise ValueError('operand type not recognized: ' + str(operand))

    def get_leaves(self):
        """Yield all variables, then all parameters, then all floats."""
        if self._vars is None:
            self._collect_leaves()
        for i in self._vars:
            yield i
        for i in self._params:
            yield i
        for i in self._floats:
            yield i

    def _str(self):
        return str(self)

    def __str__(self):
        """
        Let every operator record its string form in order and return the
        entry recorded for the last node.
        """
        val_dict = dict()
        for oper in self.operators():
            oper._str(val_dict)
        return val_dict[self.last_node()]

    def is_variable_type(self):
        return False

    def is_parameter_type(self):
        return False

    def is_float_type(self):
        return False

    def is_expression_type(self):
        return True

    def reverse_ad(self):
        """
        Run a forward sweep (``diff_up``) over the operators, seed the last
        node's derivative with 1, then run a backward sweep (``diff_down``)
        over the operators in reverse order.

        Returns
        -------
        der_dict: dict
            The derivative dictionary as filled in by the operators.
        """
        val_dict = dict()
        der_dict = dict()
        for oper in self.operators():
            oper.diff_up(val_dict, der_dict)
        der_dict[self.last_node()] = 1
        for oper in reversed(self.list_of_operators()):
            oper.diff_down(val_dict, der_dict)
        return der_dict

    def reverse_sd(self):
        """
        Same sweep structure as ``reverse_ad``, but the forward sweep calls
        ``diff_up_symbolic`` on each operator.

        Returns
        -------
        der_dict: dict
            The derivative dictionary as filled in by the operators.
        """
        val_dict = dict()
        der_dict = dict()
        for oper in self.operators():
            oper.diff_up_symbolic(val_dict, der_dict)
        der_dict[self.last_node()] = 1
        for oper in reversed(self.list_of_operators()):
            oper.diff_down(val_dict, der_dict)
        return der_dict

    def is_relational(self):
        """Return True when the last node's exact type is InequalityOperator."""
        return type(self.last_node()) in {InequalityOperator}

    def get_rpn(self, leaf_ndx_map):
        """
        Let every operator record its RPN form in order and return the entry
        recorded for the last node.

        Parameters
        ----------
        leaf_ndx_map: dict
            Passed through unchanged to each operator's ``get_rpn``.
        """
        rpn_map = dict()
        for oper in self.operators():
            oper.get_rpn(rpn_map, leaf_ndx_map)
        return rpn_map[self.last_node()]
def _register_conditional_constraint(self, con):
    """
    Create an if-else constraint on the evaluator for ``con`` and fill it in.

    Steps, in order:

    1. Gather the variables, parameters and floats used by every condition
       and every branch expression of ``con.expr``.
    2. Compute ``reverse_sd()`` for each branch expression and make sure it
       has an entry for every referenced variable (missing entries become
       ``Float(0)``; native numeric entries are wrapped in ``Float``). Floats
       appearing in those derivatives are added to the referenced floats.
    3. Register every leaf with the evaluator constraint, numbering them
       consecutively from 0: variables first, then parameters, then floats.
    4. For each condition, emit the condition RPN, the branch RPN and one
       Jacobian RPN per referenced variable, then close the condition.
    5. Record the referenced leaves for ``con`` on this object.

    Parameters
    ----------
    con: object
        Constraint whose ``expr`` exposes ``_conditions`` and ``_exprs``.
    """
    cond_con = self._evaluator.add_if_else_constraint()
    con._c_obj = cond_con
    self._con_ccon_map[con] = cond_con

    used_vars = OrderedSet()
    used_params = OrderedSet()
    used_floats = OrderedSet()

    # Conditions are scanned before branch expressions so that the insertion
    # order of the ordered sets is unchanged.
    for group in (con.expr._conditions, con.expr._exprs):
        for member in group:
            used_vars.update(member.get_vars())
            used_params.update(member.get_params())
            used_floats.update(member.get_floats())

    branch_derivs = []
    for branch in con.expr._exprs:
        grad = branch.reverse_sd()
        branch_derivs.append(grad)
        for var in used_vars:
            if v_present(grad, var):
                if type(grad[var]) in native_numeric_types:
                    grad[var] = Float(grad[var])
            else:
                grad[var] = Float(0)
            used_floats.update(grad[var].get_floats())

    leaf_ndx_map = OrderedDict()
    next_index = 0
    leaf_groups = (
        (used_vars, self._increment_var),
        (used_params, self._increment_param),
        (used_floats, self._increment_float),
    )
    for leaves, increment in leaf_groups:
        for leaf in leaves:
            leaf_ndx_map[leaf] = next_index
            next_index += 1
            cond_con.add_leaf(increment(leaf))

    for position, condition in enumerate(con.expr._conditions):
        for term in condition.get_rpn(leaf_ndx_map):
            cond_con.add_condition_rpn_term(term)

        # Indexing (rather than zip) keeps the IndexError when there are
        # fewer branch expressions than conditions.
        for term in con.expr._exprs[position].get_rpn(leaf_ndx_map):
            cond_con.add_fn_rpn_term(term)

        grad = branch_derivs[position]
        for var in used_vars:
            c_var = var._c_obj
            for term in grad[var].get_rpn(leaf_ndx_map):
                cond_con.add_jac_rpn_term(c_var, term)

        cond_con.end_condition()

    self._vars_referenced_by_con[con] = used_vars
    self._params_referenced_by_con[con] = used_params
    self._floats_referenced_by_con[con] = used_floats


def v_present(grad, var):
    """Return True when ``var`` already has an entry in ``grad``."""
    return var in grad