예제 #1
0
def simplify(expr):
    """
    Repeatedly runs `_simplify` over every node of `expr` until a pass no
    longer changes the length of `str(expr)`, then returns `expr`.

    The loop observes progress only through `str(expr)`, so the traversal is
    presumably expected to rewrite `expr` in place (TODO confirm). Only the
    *length* of the printed form is compared: a pass that changes the
    expression without changing its printed length also ends the loop.
    """
    previous_length = len(str(expr))
    while True:
        traversal.on_every_node(_simplify, expr)
        current_length = len(str(expr))
        if current_length == previous_length:
            return expr
        previous_length = current_length
예제 #2
0
def get_atomic_odes(ode_system: ODESystem) -> List[AtomicODE]:
    """
    Collects every AtomicODE node found under `ode_system.dp`.

    :param ode_system: The ODE system whose `dp` expression is searched.
    :return: The AtomicODEs in the order `traversal.on_every_node` visits
             them; duplicates are kept.
    """
    found = []
    remember = found.append

    def _collect(node):
        if not isinstance(node, AtomicODE):
            return
        remember(node)

    traversal.on_every_node(_collect, ode_system.dp)
    return found
예제 #3
0
def numbers(e: Expression) -> Set[Number]:
    """
    Collects the Number nodes occurring anywhere in `e`.

    :param e: The expression to scan.
    :return: A set of the distinct Number nodes found.
    """
    found = set()
    include = found.add

    def _visit(node: Expression):
        if not isinstance(node, Number):
            return
        include(node)

    traversal.on_every_node(_visit, e)
    return found
예제 #4
0
def variables(e: Expression) -> Set[Variable]:
    """
    Collects the Variable nodes occurring anywhere in `e`.

    This is also a minimal illustration of the client pattern for
    traversal.on_every_node: define a callback that inspects one node and
    accumulates into an enclosing collection, then hand it to the traversal.

    :param e: The expression to scan.
    :return: A set of the distinct Variables found.
    """
    found = set()
    include = found.add

    def _visit(node: Expression):
        if not isinstance(node, Variable):
            return
        include(node)

    traversal.on_every_node(_visit, e)
    return found
예제 #5
0
def generate_fresh_variable(e: Expression) -> Variable:
    """
    Builds a Variable that does not occur in `e`.

    Candidates are tried in the order "fv", "fv1", "fv11", ... (each one
    appends "1" to the previous candidate's name). The first candidate not
    equal to any Variable of `e` is returned.
    """
    occurring = []

    def _note_variable(node: Expression):
        if not isinstance(node, Variable):
            return
        occurring.append(node)

    traversal.on_every_node(_note_variable, e)

    candidate = Variable("fv")
    while candidate in occurring:
        candidate = Variable(candidate.name + "1")
    return candidate
예제 #6
0
def compute_monomials(e: Expression, of: Set[Variable]) -> List[Term]:
    """
    Finds all of the monomials of `of` in `e`.

    Every subexpression of `e` for which `is_monomial(subexpr, of)` holds is
    collected, in the order `traversal.on_every_node` visits them.

    :param e: The expression in which to search for monomials.
    :param of: The non-empty set of variables of the monomial; other
               variables are treated as constants.
    :return: A list of all monomials found.
    """
    assert len(of) != 0
    found = []

    def _visit(node: Expression):
        if not is_monomial(node, of):
            return
        found.append(node)

    traversal.on_every_node(_visit, e)
    return found
예제 #7
0
def all_dots(e: Expression) -> list:
    """
    Returns all of the distinct Dots in the expression `e`.

    A "dot" here is any DotTerm or DotFormula node.

    :param e: The expression to search.
    :return: A list of the distinct DotTerm/DotFormula nodes found. They are
             deduplicated through a set, so the order is unspecified.
    :raises Exception: Carrying `e` as its argument if the traversal raises
             MatchError; the original MatchError is chained as `__cause__`.
    """
    dots = set()

    def _dot_map(node: Expression):
        if isinstance(node, (DotTerm, DotFormula)):
            dots.add(node)

    try:
        traversal.on_every_node(_dot_map, e)
    except MatchError as err:
        # Chain explicitly so the underlying MatchError stays visible as the
        # cause of the re-raised exception.
        raise Exception(e) from err

    return list(dots)
예제 #8
0
def replace_without_copy(what: Expression, repl: Expression,
                         target_input: Expression) -> Expression:
    """
    Side-effecting function that replaces all `what`s with `repl`s in the `target_input`.
    Will print out debugging messages if REPLACEMENT_LOGGING is set to true.

    The expression is rebuilt in two phases:
      1. Flatten `target_input` into `constructor_stack` in traversal order,
         pushing `repl` in place of any node referentially equal to `what`.
      2. Pop nodes back off. `repl` and non-composite nodes go onto
         `arg_stack`; composite nodes are re-initialised in place (via
         `__init__`) with their arguments popped from `arg_stack`.

    :param what: The expression to replace.
    :param repl: The replacement expression.
    :param target_input: The expression in which the replacement should be made.
    :return: The rebuilt root, normally the (mutated) `target_input` itself.
    :raises NotImplementedError: If a binding construct (Forall, Exists,
             Program, DifferentialProgram) or a PredApp is encountered.
    :raises MatchError: If a composite node has an arity other than 1 or 2.
    """
    assert isinstance(what, Expression)
    assert isinstance(repl,
                      Expression), f"Expected expression but found {repl}"
    assert isinstance(target_input, Expression)

    # Prefix for the error messages below. It must be bound on every path:
    # previously it was assigned only when REPLACEMENT_LOGGING was enabled, so
    # with logging off the error paths failed with UnboundLocalError instead
    # of the intended AssertionError / MatchError.
    replacement_msg = ""
    if REPLACEMENT_LOGGING:
        replacement_msg = "Replacing %s with %s in %s\n" % (what, repl,
                                                            target_input)
        logging.info(replacement_msg)

    # Phase 1: flatten. NOTE(review): the rebuild below assumes
    # on_every_node visits a parent before its children (pre-order) -- TODO
    # confirm. If it also descends into a node matching `what`, that node's
    # children are pushed too; presumably `what` is expected to be a leaf.
    constructor_stack = []

    def _push_fn(e: Expression):
        if e.is_referentially_eq(what):
            constructor_stack.append(repl)
        else:
            if isinstance(e, (Forall, Exists, Program, DifferentialProgram)):
                raise NotImplementedError(
                    "Substitution only defined for expressions without binding structure."
                )
            constructor_stack.append(e)

    traversal.on_every_node(_push_fn, target_input)
    if REPLACEMENT_LOGGING:
        logging.info(
            f"initial constructor stack for {target_input} is: {constructor_stack}"
        )
    assert target_input in constructor_stack
    assert len(constructor_stack) > 0

    # Phase 2: rebuild bottom-up by popping in reverse traversal order.
    arg_stack = []
    while len(constructor_stack) > 0:
        nxt = constructor_stack.pop()
        if REPLACEMENT_LOGGING:
            logging.info(
                f"Current constructor stack: {nxt}::{constructor_stack}")
            logging.info(f"Current arg stack: {arg_stack}")
        if isinstance(nxt, (Forall, Exists, PredApp)):
            raise NotImplementedError()
        elif nxt.is_referentially_eq(repl):
            # The replacement is inserted as-is; it is never re-initialised.
            arg_stack.append(nxt)
        elif isinstance(nxt, CompositeExpression):
            assert not isinstance(
                nxt, DotTerm), "not sure why this would happen; just checking."
            if nxt.arity() == 1:
                nxt.__init__(arg_stack.pop())
            elif nxt.arity() == 2:
                assert len(arg_stack) >= 2, (
                    replacement_msg +
                    "About to re-apply arity 2 operator (%s) but only have this arg stack:\n\t%s"
                    % (nxt, arg_stack))
                # Call arguments are evaluated left to right, so the first
                # pop (the most recently rebuilt child) becomes the first
                # argument.
                nxt.__init__(arg_stack.pop(), arg_stack.pop())
            else:
                raise MatchError(
                    replacement_msg +
                    "we now have longer arities that need to be handled.")
            arg_stack.append(nxt)
        else:
            arg_stack.append(nxt)

    return arg_stack.pop()