def simplify(expr):
    """
    Repeatedly run `_simplify` over every node of `expr` until a pass no
    longer changes the length of `str(expr)`, then return `expr`.

    The loop relies on `_simplify` rewriting nodes in place: the only
    progress signal is the printed length of `expr`. A pass that changes
    the expression without changing its printed length therefore also
    ends the loop.
    """
    previous_length = len(str(expr))
    while True:
        traversal.on_every_node(_simplify, expr)
        current_length = len(str(expr))
        if current_length == previous_length:
            # Fixed point (by length) reached.
            return expr
        previous_length = current_length
def get_atomic_odes(ode_system: ODESystem) -> List[AtomicODE]:
    """
    Collect every AtomicODE node reachable from `ode_system.dp`.

    :param ode_system: The system whose differential program is searched.
    :return: The AtomicODE nodes, in the order the traversal visits them.
    """
    found = []

    def _collect(node):
        if isinstance(node, AtomicODE):
            found.append(node)

    traversal.on_every_node(_collect, ode_system.dp)
    return found
def numbers(e: Expression) -> Set[Number]:
    """
    Gather the distinct Number nodes occurring anywhere in `e`.

    :param e: The expression to search.
    :return: A set containing each Number found.
    """
    found = set()

    def _collect(node: Expression):
        if isinstance(node, Number):
            found.add(node)

    traversal.on_every_node(_collect, e)
    return found
def variables(e: Expression) -> Set[Variable]:
    """
    Gather the distinct Variable nodes occurring anywhere in `e`.

    This doubles as a minimal example of client code driving
    traversal.on_every_node: a callback inspects each visited node and
    records the ones it cares about in an enclosing collection.

    :param e: The expression to search.
    :return: A set containing each Variable found.
    """
    found = set()

    def _collect(node: Expression):
        if isinstance(node, Variable):
            found.add(node)

    traversal.on_every_node(_collect, e)
    return found
def generate_fresh_variable(e: Expression) -> Variable:
    """
    Produce a Variable that does not occur in `e`.

    Candidates are tried in the order "fv", "fv1", "fv11", ... (a "1" is
    appended each time). The first one not equal to any Variable in `e`
    is returned.

    :param e: The expression whose variables must be avoided.
    :return: A Variable absent from `e`.
    """
    seen = []

    def _record(node: Expression):
        if isinstance(node, Variable):
            seen.append(node)

    traversal.on_every_node(_record, e)

    candidate = Variable("fv")
    while candidate in seen:
        candidate = Variable(candidate.name + "1")
    return candidate
def compute_monomials(e: Expression, of: Set[Variable]) -> List[Term]:
    """
    Finds all of the monomials of `of` in `e`.

    Every sub-expression visited by the traversal is tested with
    `is_monomial`. Each match is kept, in visitation order.

    :param e: The expression in which to search for monomials.
    :param of: The set of variables of the monomial; others are treated as constants.
               Must be non-empty.
    :return: A list of all monomials.
    """
    assert len(of) > 0
    monomials = []

    def _visit(sub: Expression):
        if is_monomial(sub, of):
            monomials.append(sub)

    traversal.on_every_node(_visit, e)
    return monomials
def all_dots(e: Expression):
    """
    Return every DotTerm and DotFormula occurring in `e`, without duplicates.

    Dots are collected into a set, so equal dots collapse into one. They are
    returned as a list whose order is unspecified.

    :param e: The expression to search.
    :return: A list of the distinct DotTerm/DotFormula nodes in `e`.
    :raises Exception: wrapping `e` when the traversal raises MatchError; the
                       original MatchError is chained as ``__cause__``.
    """
    dots = set()

    def _dot_map(node: Expression):
        if isinstance(node, (DotTerm, DotFormula)):
            dots.add(node)

    try:
        traversal.on_every_node(_dot_map, e)
    except MatchError as err:
        # Keep the offending expression in the message, but preserve the
        # underlying MatchError so its traceback is not lost.
        raise Exception(e) from err
    return list(dots)
def replace_without_copy(what: Expression, repl: Expression, target_input: Expression) -> Expression:
    """
    Side-effecting function that replaces all `what`s with `repl`s in the `target_input`.

    The function works in two phases:

    1. Every node visited by traversal.on_every_node is pushed onto a
       constructor stack; nodes referentially equal to `what` are pushed
       as `repl` instead.
    2. The stack is unwound. Composite nodes are re-initialized in place
       (via ``__init__``) with children popped from an argument stack, so
       no copy of the tree is made.

    Will log debugging messages (logging.info) if REPLACEMENT_LOGGING is set to true.

    :param what: The expression to replace (matched with `is_referentially_eq`).
    :param repl: The replacement expression.
    :param target_input: The expression in which the replacement should be made.
    :return: `target_input`.
    :raises NotImplementedError: if a visited node (other than a match of `what`)
        is a Forall, Exists, Program or DifferentialProgram, or if a Forall,
        Exists or PredApp is encountered while rebuilding.
    :raises MatchError: if a composite node has an arity other than 1 or 2.
    """
    assert isinstance(what, Expression)
    assert isinstance(repl, Expression), f"Expected expression but found {repl}"
    assert isinstance(target_input, Expression)

    # The full message (which stringifies the whole target) is only built
    # when logging is enabled.
    replacement_msg = None
    if REPLACEMENT_LOGGING:
        replacement_msg = "Replacing %s with %s in %s\n" % (what, repl, target_input)
        logging.info(replacement_msg)

    def _context() -> str:
        # Prefix for error messages. It must be available even when logging
        # is off; otherwise the error paths would raise NameError instead of
        # the intended error.
        if replacement_msg is not None:
            return replacement_msg
        return "Replacing %s with %s\n" % (what, repl)

    constructor_stack = []

    def _push_fn(e: Expression):
        if e.is_referentially_eq(what):
            constructor_stack.append(repl)
        else:
            if isinstance(e, (Forall, Exists, Program, DifferentialProgram)):
                raise NotImplementedError(
                    "Substitution only defined for expressions without binding structure."
                )
            constructor_stack.append(e)

    # NOTE(review): if on_every_node also visits the children of a node that
    # matched `what`, those children are pushed too — confirm traversal semantics.
    traversal.on_every_node(_push_fn, target_input)
    if REPLACEMENT_LOGGING:
        logging.info(
            f"initial constructor stack for {target_input} is: {constructor_stack}"
        )
    assert target_input in constructor_stack
    assert len(constructor_stack) > 0

    arg_stack = []
    while len(constructor_stack) > 0:
        nxt = constructor_stack.pop()
        if REPLACEMENT_LOGGING:
            logging.info(
                f"Current constructor stack: {nxt}::{constructor_stack}")
            logging.info(f"Current arg stack: {arg_stack}")

        if isinstance(nxt, (Forall, Exists, PredApp)):
            raise NotImplementedError()
        elif nxt.is_referentially_eq(repl):
            # The replacement is used as-is; it must not be rebuilt from
            # whatever happens to be on the argument stack.
            arg_stack.append(nxt)
        elif isinstance(nxt, CompositeExpression):
            assert not isinstance(
                nxt, DotTerm), "not sure why this would happen; just checking."
            if nxt.arity() == 1:
                nxt.__init__(arg_stack.pop())
            elif nxt.arity() == 2:
                assert len(arg_stack) >= 2, (
                    _context() +
                    "About to re-apply arity 2 operator (%s) but only have this arg stack:\n\t%s"
                    % (nxt, arg_stack))
                # Arguments are evaluated left to right, so the first pop is
                # the left child (it was pushed onto arg_stack last).
                nxt.__init__(arg_stack.pop(), arg_stack.pop())
            else:
                raise MatchError(
                    _context() +
                    "we now have longer arities that need to be handled.")
            arg_stack.append(nxt)
        else:
            arg_stack.append(nxt)
    return arg_stack.pop()