def signature(self):
    """
    Return a signature uniquely defining the ODE

    The signature is the SHA-1 hex digest of a ";"-joined list of strings
    collected from every component in ``self.components``: the component
    string (skipped for the ODE itself), the sorted reprs of its states and
    parameters, the sympy code of its intermediates and the sympy code of
    the ODE state expressions ordered by state name.

    Returns
    -------
    str
        Hexadecimal SHA-1 digest
    """
    import hashlib

    def by_state_name(state_expr):
        # Order state expressions by their stringified state
        return str(state_expr.state)

    pieces = []
    for component in self.components:
        if self != component:
            pieces.append(str(component))

        # Sorting the stringified states and parameters avoids trouble with
        # the random ordering of **kwargs
        pieces.extend(sorted(repr(state.param) for state in component.full_states))
        pieces.extend(sorted(repr(param.param) for param in component.parameters))
        pieces.extend(sympycode(inter.expr) for inter in component.intermediates)

        # NOTE(review): the ODE-level ``self.state_expressions`` (not the
        # component's) is appended once per component; kept as-is because
        # changing it would change every existing signature.
        pieces.extend(
            sympycode(state_expr.expr)
            for state_expr in sorted(self.state_expressions, key=by_state_name)
        )

    return hashlib.sha1(";".join(pieces).encode("utf-8")).hexdigest()
def _expect_state(self, state, allow_state_solution=False, only_local_states=False):
    """
    Help function to check an argument which should be expected to be a state

    Arguments
    ---------
    state : State, StateSolution, sympy.AppliedUndef
        The object to check. An AppliedUndef is resolved by name through
        ``self.root.present_ode_objects``.
    allow_state_solution : bool
        If True a StateSolution is accepted in addition to a State.
    only_local_states : bool
        If True, a state given as an AppliedUndef must be found in this
        component's states (or its intermediates when
        ``allow_state_solution`` is True).

    Returns
    -------
    State or StateSolution
        The resolved and checked object
    """
    if isinstance(state, AppliedUndef):
        name = sympycode(state)
        state = self.root.present_ode_objects.get(name)
        if state is None:
            error(f"{name} is not registered in this ODE")

        # NOTE(review): the locality check only runs for symbolic input;
        # State objects passed in directly are not checked against this
        # component.
        if only_local_states:
            is_local = state in self.states or (
                state in self.intermediates and allow_state_solution
            )
            if not is_local:
                error(f"{name} is not registered in component {self.name}")

    expected_types = (State, StateSolution) if allow_state_solution else (State,)
    check_arg(state, expected_types, 0)

    # A solved state cannot receive a state expression
    if isinstance(state, State) and state.is_solved:
        error(
            "Cannot registered a state expression for a state "
            "which is registered solved.",
        )

    return state
def __init__(self, state, expr, dependent=None):
    """
    Create a StateDerivative

    Arguments
    ---------
    state : State
        The state for which the StateDerivative should apply
    expr : sympy.Basic
        The expression which the differentiation should be equal
    dependent : ODEObject
        If given the count of this StateDerivative will follow as a
        fractional count based on the count of the dependent object
    """
    check_arg(state, State, 0, StateDerivative)

    # d(state)/dt as a sympy Derivative of the state symbol w.r.t. time
    derivative = sp.Derivative(state.sym, state.time.sym)

    # Force the assumptions of the derivative symbol so it is treated as a
    # real, commutative quantity
    for key, value in (
        ("real", True),
        ("imaginary", False),
        ("commutative", True),
        ("hermitian", True),
        ("complex", True),
    ):
        derivative._assumptions[key] = value

    # The printed derivative becomes the name of the object
    super(StateDerivative, self).__init__(
        sympycode(derivative),
        state,
        expr,
        dependent,
    )
    self._sym = derivative
def add_derivative(self, der_expr, dep_var, expr, dependent=None):
    """
    Add a derivative expression

    Arguments
    ---------
    der_expr : gotran.Expression, gotran.State, sympy.AppliedUndef
        The Expression or State which is differentiated
    dep_var : gotran.State, gotran.Time, gotran.Expression, sympy.AppliedUndef, sympy.Symbol
        The dependent variable
    expr : sympy.Basic
        The expression which the differentiation should be equal
    dependent : gotran.ODEObject
        If given the count of this expression will follow as a fractional
        count based on the count of the dependent object

    Returns
    -------
    The ``sym`` of the registered StateDerivative or DerivativeExpression
    """
    timer = Timer("Add derivatives")  # noqa: F841

    def resolve(candidate, symbolic_types):
        # Map a sympy placeholder to the registered ODE object of that name
        if not isinstance(candidate, symbolic_types):
            return candidate
        name = sympycode(candidate)
        found = self.root.present_ode_objects.get(name)
        if found is None:
            error(f"{name} is not registered in this ODE")
        return found

    der_expr = resolve(der_expr, AppliedUndef)
    dep_var = resolve(dep_var, (AppliedUndef, sp.Symbol))

    if not isinstance(der_expr, State):
        # Create a DerivativeExpression in the present component
        obj = DerivativeExpression(der_expr, dep_var, expr, dependent)
    else:
        # Derivative of a state: validate it and build a StateDerivative
        self._expect_state(der_expr)
        obj = StateDerivative(der_expr, expr, dependent)

    self._register_component_object(obj, dependent)
    return obj.sym
def _args_str(self):
    """
    Return a formatted str of __init__ arguments

    The result is ``repr(to_state), repr(from_state), sympycode(expr)``
    joined by ", ".
    """
    to_state = self._to_state
    from_state = self._from_state
    return f"{to_state!r}, {from_state!r}, {sympycode(self.expr)}"
def __init__(self, state, expr, dependent=None):
    """
    Create a StateSolution

    Arguments
    ---------
    state : State
        The state that is being solved for
    expr : sympy.Basic
        The expression that should equal 0 and which solves the state
    dependent : ODEObject
        If given the count of this StateSolution will follow as a
        fractional count based on the count of the dependent object
    """
    check_arg(state, State, 0, StateSolution)

    # Forward ``dependent`` to the base constructor, as the other expression
    # constructors do, so the documented fractional count takes effect.
    super(StateSolution, self).__init__(sympycode(state.sym), expr, dependent)

    # Flag solved state
    state._is_solved = True
    self._state = state
def __init__(self, der_expr, dep_var, expr, dependent=None):
    """
    Create a DerivativeExpression

    Arguments
    ---------
    der_expr : Expression
        The Expression which is differentiated (``check_arg`` only
        accepts Expression instances)
    dep_var : State, Time, Expression
        The dependent variable
    expr : sympy.Basic
        The expression which the differentiation should be equal
    dependent : ODEObject
        If given the count of this DerivativeExpression will follow as a
        fractional count based on the count of the dependent object
    """
    check_arg(der_expr, Expression, 0, DerivativeExpression)
    check_arg(dep_var, (State, Expression, Time), 1, DerivativeExpression)

    # Check that the der_expr is dependent on var
    if dep_var.sym not in der_expr.sym.args:
        error(
            "Cannot create a DerivativeExpression as {0} is not "
            "dependent on {1}".format(der_expr, dep_var),
        )

    der_sym = sp.Derivative(der_expr.sym, dep_var.sym)
    self._der_expr = der_expr
    self._dep_var = dep_var

    super(DerivativeExpression, self).__init__(sympycode(der_sym), expr, dependent)

    # Reuse the derivative built above instead of constructing it twice, and
    # force its assumptions so it behaves as a real, commutative symbol
    self._sym = der_sym
    for key, value in (
        ("real", True),
        ("commutative", True),
        ("imaginary", False),
        ("hermitian", True),
        ("complex", True),
    ):
        self._sym._assumptions[key] = value
def register_ode_object(self, obj, comp, dependent=None):
    """
    Register an ODE object in the root ODEComponent

    Handles name clashes with already registered objects, then updates
    ``present_ode_objects``, ``object_component`` and ``ns``. For
    Expressions it also expands derivatives and records dependencies in
    ``expression_dependencies`` / ``object_used_in``.

    Arguments
    ---------
    obj : ODEObject
        The object to register
    comp : ODEComponent
        The component the object belongs to
    dependent : ODEObject (optional)
        Forwarded to derivative expansion and used when recounting ``obj``
        if new derivative expressions were added
    """
    from modelparameters.sympytools import symbols_from_expr

    if self._is_finalized_ode and isinstance(obj, StateExpression):
        error("Cannot register a StateExpression, the ODE is finalized")

    # Check for existing object in the ODE
    dup_obj = self.present_ode_objects.get(obj.name)

    # If object with same name is already registered in the ode we
    # need to figure out what to do
    if dup_obj:

        try:
            dup_comp = self.object_component[dup_obj]
        except KeyError:
            dup_comp = None

        # If a state is substituted by a state solution
        if isinstance(dup_obj, State) and isinstance(obj, StateSolution):
            debug(f"Reduce state '{dup_obj}' to {obj.expr}")

        # If duplicated object is an ODE Parameter and the added object is
        # either a State or a Parameter we replace the Parameter.
        elif (
            isinstance(dup_obj, Parameter)
            and dup_comp == self
            and comp != self
            and isinstance(obj, (State, Parameter))
        ):

            timer = Timer("Replace objects")  # noqa: F841

            # Remove the object
            self.ode_objects.remove(dup_obj)

            # FIXME: Do we need to recreate all expression the objects is used in?
            # Recursively replace object dependencies. ``subs`` starts empty;
            # presumably _replace_object fills it -- verify.
            subs = {}
            self._replace_object(dup_obj, obj, subs)

        # If duplicated object is an ODE Parameter and the added
        # object is an Intermediate we raise an error.
        elif (
            isinstance(dup_obj, Parameter)
            and dup_comp == self
            and isinstance(obj, Expression)
        ):
            error(
                "Cannot replace an ODE parameter with an Expression, "
                "only with Parameters and States.",
            )

        # If State, Parameter or DerivativeExpression we always raise an error
        elif any(
            isinstance(
                oo,
                (
                    State,
                    Parameter,
                    Time,
                    Dt,
                    DerivativeExpression,
                    AlgebraicExpression,
                    StateSolution,
                ),
            )
            for oo in [dup_obj, obj]
        ):
            error(
                "Cannot register {0}. A {1} with name '{2}' is "
                "already registered in this ODE.".format(
                    type(obj).__name__,
                    type(dup_obj).__name__,
                    dup_obj.name,
                ),
            )
        else:

            # Sanity check that both obj and dup_obj are Expressions
            assert all(isinstance(oo, (Expression)) for oo in [dup_obj, obj])

            # Get list of duplicated objects; seed it with the first
            # registered object the first time a duplicate shows up
            dup_objects = self.duplicated_expressions[obj.name]
            if not dup_objects:
                dup_objects.append(dup_obj)
            dup_objects.append(obj)

    # Update global information about ode object
    self.present_ode_objects[obj.name] = obj
    self.object_component[obj] = comp
    self.ns.update({obj.name: obj.sym})

    # If Expression
    if isinstance(obj, Expression):

        # Append the name to the list of all ordered components with
        # expressions. If the ODE is finalized we do not update components
        if not self._is_finalized_ode:
            self._handle_expr_component(comp, obj)

        # Expand and add any derivatives in the expressions. Sort them so
        # the expansion order is deterministic.
        expression_added = False
        replace_dict = {}
        derivative_expression_list = sorted(
            obj.expr.atoms(sp.Derivative),
            key=lambda e: e.sort_key(),
        )
        for der_expr in derivative_expression_list:
            expression_added |= self._expand_single_derivative(
                comp,
                obj,
                der_expr,
                replace_dict,
                dependent,
            )

        # If expressions need to be re-created
        if replace_dict:
            obj.replace_expr(replace_dict)

        # If any expression was added we need to bump the count of the ODEObject
        if expression_added:
            obj._recount(dependent=dependent)

        # Add dependencies between the last registered comment and
        # expressions so they are carried over in Code components
        if comp._local_comments:
            last_comment = comp._local_comments[-1]
            self.object_used_in[last_comment].add(obj)
            self.expression_dependencies[obj].add(last_comment)

        # Get expression dependencies. Use .get() so an undeclared symbol
        # reaches the explicit error below; with plain indexing a missing
        # key in a plain dict would raise KeyError before the check.
        for sym in symbols_from_expr(obj.expr, include_derivatives=True):

            dep_obj = self.present_ode_objects.get(sympycode(sym))

            if dep_obj is None:
                error(
                    "The symbol '{0}' is not declared within the '{1}' "
                    "ODE.".format(sym, self.name),
                )

            # Store object dependencies
            self.expression_dependencies[obj].add(dep_obj)
            self.object_used_in[dep_obj].add(obj)

        # If the expression is a StateSolution the state cannot have
        # been used previously
        if isinstance(obj, StateSolution) and self.object_used_in.get(obj.state):
            used_in = self.object_used_in.get(obj.state)
            error(
                "A state solution cannot have been used in "
                "any previous expressions. {0} is used in: {1}".format(
                    obj.state,
                    used_in,
                ),
            )
def save(self, basename=None):
    """
    Save ODE to file

    Writes ``<basename>.ode`` containing ``states(...)`` and
    ``parameters(...)`` declarations for every component, followed by the
    expressions and comments of every expression component. The lines are
    passed through ``PythonCodeGenerator.indent_and_split_lines`` before
    being written.

    Arguments
    ---------
    basename : str (optional)
        The basename of the file which the ode will be saved to, if not
        given the basename will be the same as the name of the ode.
    """
    timer = Timer("Save " + self.name)  # noqa: F841

    if not self._is_finalized_ode:
        error("ODE need to be finalized to be saved to file.")

    lines = ["# Saved Gotran model"]

    comp_names = dict()
    basename = basename or self.name

    def quoted_path(component):
        # Quoted component names, outermost parent (below the ODE) first
        chain = [component.name]
        current = component
        while current.parent != self:
            current = current.parent
            chain.append(current.name)
        return ", ".join(f'"{name}"' for name in reversed(chain))

    def emit_declaration(kind, comp_name, entries):
        # Append a states(...)/parameters(...) call. Every entry ends with a
        # trailing comma, which is replaced by ")" on the last line.
        if not entries:
            return
        lines.append("")
        if comp_name:
            lines.append(f"{kind}({comp_name},")
        else:
            lines.append(f"{kind}({entries.pop(0)}")
        lines.extend(" " + entry for entry in entries)
        lines[-1] = lines[-1][:-1] + ")"

    for comp in self.components:
        comp_name = "" if comp == self else quoted_path(comp)
        comp_names[comp] = comp_name

        state_entries = [
            f"{obj.name}={obj.param.repr(include_name=False)},"
            for obj in comp.ode_objects
            if isinstance(obj, State)
        ]
        param_entries = [
            f"{obj.name}={obj.param.repr(include_name=False)},"
            for obj in comp.ode_objects
            if isinstance(obj, Parameter)
        ]

        emit_declaration("states", comp_name, state_entries)
        emit_declaration("parameters", comp_name, param_entries)

    # Iterate over all components that contain expressions
    for expr_comp_key in self.all_expr_components_ordered:

        comp = self.all_components[expr_comp_key]
        comp_comment = f"Expressions for the {comp.name} component"

        # Only expressions and comments are saved
        for obj in comp.ode_objects:

            if isinstance(obj, Expression):

                # Markov models: skip state derivatives and save rate
                # expressions with the rates[...] syntax
                if comp.rates:
                    if isinstance(obj, StateDerivative):
                        continue
                    if isinstance(obj, RateExpression):
                        lines.append(
                            "rates[{0}, {1}] = {2}".format(
                                sympycode(obj.states[0]),
                                sympycode(obj.states[1]),
                                sympycode(obj.expr),
                            ),
                        )
                        continue

                lines.append(f"{obj.name} = {sympycode(obj.expr)}")

            elif isinstance(obj, Comment):
                lines.append("")
                # The component comment opens an expressions(...) section;
                # any other comment is saved as a plain comment(...)
                if str(obj) == comp_comment:
                    target = comp_names[comp] or f'"{basename}"'
                    lines.append(f"expressions({target})")
                else:
                    lines.append(f'comment("{obj}")')

    lines.append("")

    # Use Python code generator to indent outputted code
    from gotran.codegeneration.codegenerators import PythonCodeGenerator

    with open(basename + ".ode", "w") as f:
        f.write("\n".join(PythonCodeGenerator.indent_and_split_lines(lines)))
def _expand_single_derivative(self, comp, obj, der_expr, replace_dict, dependent):
    """
    Expand a single derivative and register it as new derivative expression

    Arguments
    ---------
    comp : ODEComponent
        Component in which a new derivative expression is added
    obj : Expression
        The expression in which ``der_expr`` occurs (used in error messages)
    der_expr : sympy.Derivative
        The derivative to expand
    replace_dict : dict
        Populated with a replacement for ``der_expr`` if it is trivial
    dependent : ODEObject
        Forwarded to ``comp.add_derivative``

    Returns
    -------
    bool
        True if an expression was actually added
    """
    # Try accessing already registered derivative expressions; if it exists
    # there is nothing to do
    if self.present_ode_objects.get(sympycode(der_expr)):
        return False

    # Use the public Derivative API instead of positional args: since
    # SymPy 1.2 ``Derivative.args[1]`` is a ``(variable, count)`` Tuple, and
    # higher order derivatives carry more than one variable.
    diff_expr = der_expr.expr
    variables = der_expr.variables

    # Try expand the given derivative if it is directly expandable just
    # add a replacement for the derivative result
    der_result = diff_expr.diff(*variables)
    if not der_result.atoms(sp.Derivative):
        replace_dict[der_expr] = der_result
        return False

    if not isinstance(diff_expr, AppliedUndef):
        error(
            "Can only register Derivatives of allready registered "
            "Expressions. Got: {0}".format(sympycode(diff_expr)),
        )

    if len(variables) != 1 or not isinstance(
        variables[0],
        (AppliedUndef, sp.Symbol),
    ):
        error(
            "Can only register Derivatives with a single dependent "
            "variabe. Got: {0}".format(
                ", ".join(sympycode(var) for var in variables),
            ),
        )
    dep_var = variables[0]

    # Get the expr and dependent variable objects
    expr_obj = self.present_ode_objects.get(sympycode(diff_expr))
    var_obj = self.present_ode_objects.get(sympycode(dep_var))
    for sym, ode_obj in ((diff_expr, expr_obj), (dep_var, var_obj)):
        if ode_obj is None:
            error(f"{sympycode(sym)} is not registered in this ODE")

    # If the dependent variable is time and the expression is a state
    # variable we raise an error as the user should already have created
    # the expression.
    if isinstance(expr_obj, State) and var_obj == self._time:
        error(
            "The expression {0} is dependent on the state "
            "derivative of {1} which is not registered in this ODE.".format(
                obj,
                expr_obj,
            ),
        )

    if not isinstance(expr_obj, Expression):
        error(
            "Can only differentiate expressions or states. Got {0} as "
            "the derivative expression.".format(expr_obj),
        )

    # Expand derivative and see if it is trivial
    der_result = expr_obj.expr.diff(var_obj.sym)

    # If derivative result are trival we substitute it
    if (
        der_result.is_number
        or isinstance(der_result, (sp.Symbol, AppliedUndef))
        or (
            isinstance(der_result, (sp.Mul, sp.Pow, sp.Add))
            and len(der_result.args) == 2
            and all(
                isinstance(arg, (sp.Number, sp.Symbol, AppliedUndef))
                for arg in der_result.args
            )
        )
    ):
        replace_dict[der_expr] = der_result
        return False

    # Store expression
    comp.add_derivative(expr_obj, var_obj, der_result, dependent)
    return True
def __init__(
    self,
    factorized,
    function_name="forward_backward_subst",
    result_name="dx",
    residual_name="F",
    params=None,
):
    """
    Create a JacobianForwardBackwardSubstComponent

    Arguments
    ---------
    factorized : gotran.FactorizedJacobianComponent
        The factorized jacobian of the ODE
    function_name : str
        The name of the function which should be generated
    result_name : str
        The name of the result (increment)
    residual_name : str
        The name of the residual
    params : dict
        Parameters determining how the code should be generated
    """
    timer = Timer("Computing forward backward substituion component")  # noqa: F841
    check_arg(factorized, FactorizedJacobianComponent)
    jacobian_name = list(factorized.shapes.keys())[0]
    descr = (
        "Symbolically forward backward substitute linear system "
        "of {0} ODE".format(factorized.root)
    )
    super(ForwardBackwardSubstitutionComponent, self).__init__(
        "ForwardBackwardSubst",
        factorized.root,
        function_name,
        descr,
        params=params,
        use_default_arguments=False,
        additional_arguments=[residual_name],
    )

    self.add_comment(
        f"Forward backward substituting factorized linear system {self.root.name}",
    )

    # Recreate jacobian using only sympy Symbols (one per nonzero entry);
    # the former debug print of every entry has been removed
    jac_orig = factorized.factorized_jacobian

    # Size of system
    n = jac_orig.rows
    jac = sp.Matrix(n, n, lambda i, j: sp.S.Zero)

    for i in range(n):
        for j in range(n):
            if not jac_orig[i, j].is_zero:
                name = sympycode(jac_orig[i, j])
                jac[i, j] = sp.Symbol(
                    name,
                    real=True,
                    imaginary=False,
                    commutative=True,
                    hermitian=True,
                    complex=True,
                )

    self.shapes[jacobian_name] = (n, n)
    self.shapes[residual_name] = (n,)
    self.shapes[result_name] = (n,)

    # NOTE(review): any row permutation from pivoting in the factorization
    # is not applied to the residual here -- verify no pivoting occurs.
    F = []
    dx = []
    # forward substitution, all diag entries are scaled to 1
    for i in range(n):
        F.append(self.add_indexed_object(residual_name, i))
        dx.append(self.add_indexed_expression(result_name, i, F[i]))
        for j in range(i):
            if jac[i, j].is_zero:
                continue
            dx[i] = self.add_indexed_expression(
                result_name,
                i,
                dx[i] - dx[j] * jac[i, j],
            )

    # backward substitution
    for i in range(n - 1, -1, -1):
        for j in range(i + 1, n):
            if jac[i, j].is_zero:
                continue
            dx[i] = self.add_indexed_expression(
                result_name,
                i,
                dx[i] - dx[j] * jac[i, j],
            )
        dx[i] = self.add_indexed_expression(result_name, i, dx[i] / jac[i, i])

    # No need to call recreate body expressions
    self.body_expressions = [
        obj
        for obj in self.ode_objects
        if isinstance(obj, (IndexedExpression, Comment))
    ]

    self.results = [result_name]
    self.used_states = set()
    self.used_parameters = set()
def __init__(self, jacobian, function_name="lu_factorize", params=None):
    """
    Create a FactorizedJacobianComponent

    Symbolically LU-factorizes the Jacobian in place, emitting an indexed
    expression for every entry that changes.

    Arguments
    ---------
    jacobian : gotran.JacobianComponent
        The Jacobian of the ODE
    function_name : str
        The name of the function which should be generated
    params : dict
        Parameters determining how the code should be generated
    """
    timer = Timer("Computing factorization of jacobian")  # noqa: F841
    check_arg(jacobian, JacobianComponent)
    descr = f"Symbolically factorize the jacobian of the {jacobian.root} ODE"
    super(FactorizedJacobianComponent, self).__init__(
        "FactorizedJacobian",
        jacobian.root,
        function_name,
        descr,
        params=params,
        use_default_arguments=False,
        additional_arguments=jacobian.results,
    )

    self.add_comment(f"Factorizing jacobian of {self.root.name}")

    jacobian_name = jacobian.results[0]

    # Recreate jacobian using only sympy Symbols (one per nonzero entry);
    # the former debug prints have been removed
    jac_orig = jacobian.jacobian

    # Size of system
    n = jac_orig.rows
    jac = sp.Matrix(n, n, lambda i, j: sp.S.Zero)

    for i in range(n):
        for j in range(n):
            if not jac_orig[i, j].is_zero:
                name = sympycode(jac_orig[i, j])
                jac[i, j] = sp.Symbol(
                    name,
                    real=True,
                    imaginary=False,
                    commutative=True,
                    hermitian=True,
                    complex=True,
                )

    self.shapes[jacobian_name] = (n, n)

    def add_intermediate_if_changed(jac, jac_ij, i, j):
        # Emit an indexed expression only when the entry actually changed
        if jac_ij != jac[i, j]:
            jac[i, j] = self.add_indexed_expression(jacobian_name, (i, j), jac_ij)

    # Do the factorization
    for j in range(n):
        for i in range(j):

            # Get sympy expr of A_ij
            jac_ij = jac[i, j]

            # Build sympy expression
            for k in range(i):
                jac_ij -= jac[i, k] * jac[k, j]

            add_intermediate_if_changed(jac, jac_ij, i, j)

        pivot = -1

        for i in range(j, n):

            # Get sympy expr of A_ij
            jac_ij = jac[i, j]

            # Build sympy expression
            for k in range(j):
                jac_ij -= jac[i, k] * jac[k, j]

            add_intermediate_if_changed(jac, jac_ij, i, j)

            # find the first non-zero pivot, includes any expression
            if pivot == -1 and jac[i, j]:
                pivot = i

        if pivot < 0:
            # this result is based on iszerofunc's analysis of the
            # possible pivots, so even though the element may not be
            # strictly zero, the supplied iszerofunc's evaluation gave
            # True
            error("No nonzero pivot found; symbolic inversion failed.")

        if pivot != j:  # row must be swapped
            # NOTE(review): the row permutation is not recorded anywhere, so
            # consumers of the factorization cannot apply it -- verify
            # that pivoting never happens for supported models.
            jac.row_swap(pivot, j)

        # Scale with diagonal
        if not jac[j, j]:
            error("Diagonal element of the jacobian is zero. " "Inversion failed")

        scale = 1 / jac[j, j]
        for i in range(j + 1, n):

            # Get sympy expr of A_ij
            jac_ij = jac[i, j]
            jac_ij *= scale
            add_intermediate_if_changed(jac, jac_ij, i, j)

    # Store factorized jacobian
    self.factorized_jacobian = jac
    self.num_nonzero = sum(
        not jac[i, j].is_zero for i in range(n) for j in range(n)
    )

    # No need to call recreate body expressions
    self.body_expressions = self.ode_objects

    self.used_states = set()
    self.used_parameters = set()
def __init__(self, name, expr, dependent=None):
    """
    Create an Expression with an associated name

    The expression is sympified. Subs objects are inlined, im(x) is
    replaced by 0 and re(x) by x. The symbol of the Expression is a
    function of the sorted free symbols of the expression, or a plain
    symbol if there are none.

    Arguments
    ---------
    name : str
        The name of the Expression
    expr : sympy.Basic
        The expression
    dependent : ODEObject
        If given the count of this Expression will follow as a fractional
        count based on the count of the dependent object
    """
    from modelparameters.sympytools import symbols_from_expr

    # Check arguments
    check_arg(expr, scalars + (sp.Basic,), 1, Expression)
    expr = sp.sympify(expr)

    # Inline each Subs by substituting its point values into its expression
    for subs_atom in expr.atoms(sp.Subs):
        mapping = dict(zip(subs_atom.variables, subs_atom.point))
        expr = expr.subs(subs_atom, subs_atom.expr.xreplace(mapping))

    # Treat the expression as real valued: im(x) -> 0, re(x) -> x
    replacements = {im_atom: sp.S.Zero for im_atom in expr.atoms(sp.im)}
    replacements.update({re_atom: re_atom.args[0] for re_atom in expr.atoms(sp.re)})
    if replacements:
        expr = expr.xreplace(replacements)

    if not symbols_from_expr(expr, include_numbers=True):
        error(
            "expected the expression to contain at least one "
            "Symbol or Number.",
        )

    # Call super class with expression as the "value"
    super(Expression, self).__init__(name, expr, dependent)

    # Free symbols of the expression, ordered by their printed code
    dep_syms = tuple(sorted(symbols_from_expr(expr), key=sympycode))

    if not dep_syms:
        self._sym = self.param.sym
    else:
        self._sym = sym = self._param.sym(*dep_syms)
        for key, value in (
            ("real", True),
            ("commutative", True),
            ("imaginary", False),
            ("hermitian", True),
            ("complex", True),
        ):
            sym._assumptions[key] = value

    self._dependent = set(dep_syms)
def _recreate_body(self, body_expressions, **results):
    """
    Create body expressions based on the given result_expressions

    In this method are all expressions replaced with something that should
    be used to generate code. The parameters in:

        parameters["generation"]["code"]

    decides how parameters, states, body expressions and indexed expressions
    are represented.

    Arguments
    ---------
    body_expressions : list
        Expressions and Comments to recreate, in order
    **results : list
        Result name mapped to the list of Expressions/Comments that make
        up that result (validated with check_kwarg)

    Returns
    -------
    list
        The recreated body expressions, or None when both
        ``body_expressions`` and ``results`` are empty
    """
    if not (results or body_expressions):
        return

    for result_name, result_expressions in list(results.items()):
        check_kwarg(
            result_expressions,
            result_name,
            list,
            context=CodeComponent._recreate_body,
            itemtypes=(Expression, Comment),
        )

    # Extract all result expressions
    result_expressions = sum(list(results.values()), [])

    # A map between result expression and result name
    result_names = dict(
        (result_expr, result_name)
        for result_name, result_exprs in list(results.items())
        for result_expr in result_exprs
    )

    timer = Timer(f"Recreate body expressions for {self.name}")  # noqa: F841

    # Initialize the replace_dictionaries
    # NOTE(review): replace_dict is self.param_state_replace_dict and is
    # mutated below -- confirm that property returns a fresh dict.
    replace_dict = self.param_state_replace_dict
    der_replace_dict = {}

    # Get a copy of the map of where objects are used in and their
    # present dependencies so any updates done in these dictionaries does not
    # affect the original dicts
    object_used_in = defaultdict(set)
    for expr, used in list(self.root.object_used_in.items()):
        object_used_in[expr].update(used)

    expression_dependencies = defaultdict(set)
    for expr, deps in list(self.root.expression_dependencies.items()):
        expression_dependencies[expr].update(deps)

    # Get body parameters
    body_repr = self._params["body"]["representation"]
    optimize_exprs = self._params["body"]["optimize_exprs"]

    # Set body related variables if the body should be represented by an array
    if "array" in body_repr:
        body_name = self._params["body"]["array_name"]
        available_indices = deque()
        max_index = -1
        body_ind = 0
        index_available_at = defaultdict(list)

        # The body array must not clash with ANY result name (previously
        # only the last name of the loop above was compared, and a NameError
        # was raised when no results were given)
        if body_name in results:
            error("body and result cannot have the same name.")

        # Initiate shapes with inf
        self.shapes[body_name] = (float("inf"),)

    # Iterate over body expressions and recreate the different expressions
    # according to state, parameters, body and result expressions
    replaced_expr_map = OrderedDict()
    new_body_expressions = []

    present_ode_objects = dict(
        (state.name, state) for state in self.root.full_states
    )
    present_ode_objects.update(
        (param.name, param) for param in self.root.parameters
    )
    old_present_ode_objects = present_ode_objects.copy()

    def store_expressions(expr, new_expr):
        "Help function to store new expressions"
        timer = Timer(  # noqa: F841
            f"Store expression while recreating body of {self.name}",
        )  # noqa: F841

        # Update sym replace dict
        if isinstance(expr, Derivatives):
            der_replace_dict[expr.sym] = new_expr.sym
        else:
            replace_dict[expr.sym] = new_expr.sym

        # Store the new expression for later references
        present_ode_objects[expr.name] = new_expr
        replaced_expr_map[expr] = new_expr

        # Append the new expression
        new_body_expressions.append(new_expr)

        # Update dependency information
        if expr in object_used_in:
            for dep in object_used_in[expr]:
                if dep in expression_dependencies:
                    expression_dependencies[dep].remove(expr)
                    expression_dependencies[dep].add(new_expr)

            object_used_in[new_expr] = object_used_in.pop(expr)

        if expr in expression_dependencies:
            expression_dependencies[new_expr] = expression_dependencies.pop(expr)

    self.add_comment("Recreated body expressions")

    # The main iteration over all body_expressions
    for expr in body_expressions:

        # 1) Comments
        if isinstance(expr, Comment):
            new_body_expressions.append(expr)
            continue

        assert isinstance(expr, Expression)

        # 2) Check for expression optimizations
        if not (optimize_exprs == "none" or expr in result_expressions):

            timer_opt = Timer(  # noqa: F841
                f"Handle expression optimization for {self.name}",
            )  # noqa: F841

            # If expr is just a number we exchange the expression with the
            # number
            if "numerals" in optimize_exprs and isinstance(expr.expr, sp.Number):
                replace_dict[expr.sym] = expr.expr

                # Remove information about this expr beeing used
                for dep in object_used_in[expr]:
                    expression_dependencies[dep].remove(expr)
                object_used_in.pop(expr)
                continue

            # If the expr is just a symbol (symbol multiplied with a scalar)
            # we exchange the expression with the sympy expressions
            elif "symbols" in optimize_exprs and (
                isinstance(expr.expr, (sp.Symbol, AppliedUndef))
                or (
                    isinstance(expr.expr, sp.Mul)
                    and len(expr.expr.args) == 2
                    and isinstance(expr.expr.args[1], (sp.Symbol, AppliedUndef))
                    and expr.expr.args[0].is_number
                )
            ):

                # Add a replace rule based on the stored sympy expression
                sympy_expr = expr.expr.xreplace(der_replace_dict).xreplace(
                    replace_dict,
                )

                if isinstance(expr.sym, sp.Derivative):
                    der_replace_dict[expr.sym] = sympy_expr
                else:
                    replace_dict[expr.sym] = sympy_expr

                # Get exchanged repr
                if isinstance(expr.expr, (sp.Symbol, AppliedUndef)):
                    name = sympycode(expr.expr)
                else:
                    name = sympycode(expr.expr.args[1])

                dep_expr = present_ode_objects[name]

                # If using reused body expressions we need to update the
                # index information so that the index previously available
                # for this expressions gets available at the last expressions
                # the present expression is used in.
                # "reused" is tested first: body_name only exists when the
                # body is an array, so evaluating it for a named body
                # representation raised NameError.
                if (
                    "reused" in body_repr
                    and isinstance(dep_expr, IndexedExpression)
                    and dep_expr.basename == body_name
                ):
                    ind = dep_expr.indices[0]

                    # Remove available index information
                    dep_used_in = sorted(object_used_in[dep_expr])
                    for used_expr in dep_used_in:
                        if ind in index_available_at[used_expr]:
                            index_available_at[used_expr].remove(ind)

                    # Update with new indices
                    all_used_in = object_used_in[expr].copy()
                    all_used_in.update(dep_used_in)

                    for used_expr in sorted(all_used_in, reverse=True):
                        if used_expr in body_expressions:
                            index_available_at[used_expr].append(ind)
                            break

                # Update information about this expr beeing used
                for dep in object_used_in[expr]:
                    expression_dependencies[dep].remove(expr)
                    expression_dependencies[dep].add(dep_expr)

                object_used_in.pop(expr)
                continue

            del timer_opt

        # 3) General operations for all Expressions that are kept

        # Before we process the expression we check if any indices gets
        # available with the expr (Only applies for the "reused" option for
        # body_repr.)
        if "reused" in body_repr:

            # Check if any indices are available at this expression ind
            available_indices.extend(index_available_at[expr])

        # Store a map of old name this will preserve the ordering of
        # expressions with the same name, similar to how this is treated in
        # the actual ODE.
        present_ode_objects[expr.name] = expr
        old_present_ode_objects[expr.name] = expr

        # 4) Handle result expression
        if expr in result_expressions:

            timer_result = Timer(  # noqa: F841
                f"Handle result expressions for {self.name}",
            )  # noqa: F841

            # Get the result name
            result_name = result_names[expr]

            # If the expression is an IndexedExpression with the same basename
            # as the result name we just recreate it
            if (
                isinstance(
                    expr,
                    (
                        IndexedExpression,
                        StateIndexedExpression,
                        ParameterIndexedExpression,
                    ),
                )
                and result_name == expr.basename
            ):

                new_expr = recreate_expression(expr, der_replace_dict, replace_dict)

            # Not an indexed expression
            else:

                # Get index based on the original ordering
                index = (results[result_name].index(expr),)

                # Create the IndexedExpression
                # NOTE: First replace any derivative expression replaces, then state and
                # NOTE: params
                if isinstance(expr, StateDerivative):
                    new_expr = StateIndexedExpression(
                        result_name,
                        index,
                        expr.expr.xreplace(der_replace_dict).xreplace(replace_dict),
                        expr.state,
                        (len(results[result_name]),),
                        array_params=self._params.array,
                    )
                else:
                    new_expr = IndexedExpression(
                        result_name,
                        index,
                        expr.expr.xreplace(der_replace_dict).xreplace(replace_dict),
                        (len(results[result_name]),),
                        array_params=self._params.array,
                    )

                if new_expr.basename not in self.indexed_map:
                    self.indexed_map[new_expr.basename] = OrderedDict()
                self.indexed_map[new_expr.basename][expr] = new_expr

            # Copy counter from old expression so it sort properly
            new_expr._recount(expr._count)

            # Store the expressions
            store_expressions(expr, new_expr)

            del timer_result

        # 4) Handle indexed expression
        # All indexed expressions are just kept but recreated with updated
        # sympy expressions
        elif isinstance(expr, IndexedExpression):

            timer_indexed = Timer(  # noqa: F841
                f"Handle indexed expressions for {self.name}",
            )  # noqa: F841

            new_expr = recreate_expression(expr, der_replace_dict, replace_dict)

            # Store the expressions
            store_expressions(expr, new_expr)

            del timer_indexed

        # 5) If replacing all body exressions with an indexed expression
        elif "array" in body_repr:

            timer_body = Timer(  # noqa: F841
                f"Handle body expressions for {self.name}",
            )  # noqa: F841

            # 5a) If we reuse array indices
            if "reused" in body_repr:

                if available_indices:
                    ind = available_indices.popleft()
                else:
                    max_index += 1
                    ind = max_index

                # Check when present ind gets available again
                for used_expr in sorted(object_used_in[expr], reverse=True):
                    if used_expr in body_expressions:
                        index_available_at[used_expr].append(ind)
                        break
                else:
                    warning("SHOULD NOT COME HERE!")

            # 5b) No reuse of array indices. Here each index corresponds to
            #     a distinct body expression
            else:

                ind = body_ind

                # Increase body_ind
                body_ind += 1

            # Create the IndexedExpression
            new_expr = IndexedExpression(
                body_name,
                ind,
                expr.expr.xreplace(der_replace_dict).xreplace(replace_dict),
                array_params=self._params.array,
                enum=expr.name,
            )

            if body_name not in self.indexed_map:
                self.indexed_map[body_name] = OrderedDict()
            self.indexed_map[body_name][expr] = new_expr

            # Copy counter from old expression so they sort properly
            new_expr._recount(expr._count)

            # Store the expressions
            store_expressions(expr, new_expr)

            del timer_body

        # 6) If the expression is just an ordinary body expression and we
        #    are using named representation of body
        else:

            timer_expr = Timer(f"Handle expressions for {self.name}")  # noqa: F841

            # If the expression is a state derivative we need to add a
            # replacement for the Derivative symbol
            if isinstance(expr, StateDerivative):
                new_expr = Intermediate(
                    sympycode(expr.sym),
                    expr.expr.xreplace(der_replace_dict).xreplace(replace_dict),
                )
                new_expr._recount(expr._count)
            else:
                new_expr = recreate_expression(expr, der_replace_dict, replace_dict)

            del timer_expr

            # Store the expressions
            store_expressions(expr, new_expr)

    # Store indices for any added arrays
    if "reused_array" == body_repr:

        if max_index > -1:
            self.shapes[body_name] = (max_index + 1,)
        else:
            self.shapes.pop(body_name)

    elif "array" == body_repr:

        if body_ind > 0:
            self.shapes[body_name] = (body_ind,)
        else:
            self.shapes.pop(body_name)

    # Store the shape of the added result expressions
    for result_name, result_expressions in list(results.items()):
        if result_name not in self.shapes:
            self.shapes[result_name] = (len(result_expressions),)

    return new_body_expressions
def _body_from_cse(self, **results):
    """
    Build body and result expressions using common sub expression
    elimination (CSE).

    The expanded result expressions are reduced with ``cse`` (symbols
    named ``cse_0``, ``cse_1``, ...). Common sub expressions are
    registered in this component as Intermediates named
    ``cse_<n>``. Exceptions are expressions that are plain ODE
    states/parameters, Piecewise or Relational/Boolean; these are only
    substituted. Each result expression is registered as an
    IndexedExpression right after the last common sub expression it
    uses. Result expressions that use no common sub expression at all
    are registered after every common sub expression.

    Arguments
    ---------
    **results
        Result expressions keyed by result name, forwarded to
        ``self._expanded_result_expressions``.

    Returns
    -------
    new_results : defaultdict(list)
        The registered result expressions keyed by result name.
    body_expressions : list
        The leading comment, the cse intermediates and the result
        expressions, in the order they were registered.

    Also sets ``self.used_states`` and ``self.used_parameters`` to the
    sorted States and Parameters the registered expressions depend on.
    """
    timer = Timer(f"Compute common sub expressions for {self.name}")  # noqa: F841

    (
        orig_result_expressions,
        result_names,
        expanded_result_exprs,
    ) = self._expanded_result_expressions(**results)

    state_offset = self._params["states"]["add_offset"]

    # Collect results and body_expressions
    body_expressions = []
    new_results = defaultdict(list)

    # Collect what states and parameters have been used
    used_states = set()
    used_parameters = set()

    def _collect_dependencies(obj):
        # Record the States and Parameters a registered object depends on
        for dep in self.root.expression_dependencies[obj]:
            if isinstance(dep, State):
                used_states.add(dep)
            elif isinstance(dep, Parameter):
                used_parameters.add(dep)

    might_take_time = len(orig_result_expressions) >= 40

    if might_take_time:
        info(
            f"Computing common sub expressions for {self.name}. Might take "
            "some time...",
        )
        sys.stdout.flush()

    # Call sympy common sub expression reduction
    cse_exprs, cse_result_exprs = cse(
        expanded_result_exprs,
        symbols=sp.numbered_symbols("cse_"),
        optimizations=[],
    )

    # Map cse_sym -> cse_expr, keeping the order returned by cse
    cse_exprs = OrderedDict(cse_exprs)

    # Extract the symbols into a set for fast comparison
    cse_syms = set(cse_exprs)

    # A map between result expr and name and indices so we can
    # construct IndexedExpressions
    result_expr_map = defaultdict(list)

    # A map between the last cse_sym used in a result expr and those
    # result exprs, so that each result expression can be put right
    # after the last cse_expr it uses.
    last_cse_expr_used_in_result_expr = defaultdict(list)

    # Distinct result expressions that do not contain any cse_sym
    result_expr_without_cse_syms = []

    # A map between cse_sym and its substitutes
    cse_subs = {}

    for ind, (orig_result_expr, result_expr) in enumerate(
        zip(orig_result_expressions, cse_result_exprs),
    ):
        # Collect information so that we can recreate the result
        # expression from it
        result_expr_map[result_expr].append(
            (
                result_names[orig_result_expr],
                orig_result_expr.indices
                if isinstance(orig_result_expr, IndexedExpression)
                else ind,
            ),
        )

        used_cse_syms = [atom for atom in result_expr.atoms() if atom in cse_syms]

        if not used_cse_syms:
            # Store each distinct expression once; all of its
            # (name, indices) pairs are kept in result_expr_map
            if result_expr not in result_expr_without_cse_syms:
                result_expr_without_cse_syms.append(result_expr)
            continue

        # Last cse_sym used in the result expression (highest cse number;
        # names are "cse_<n>", so name[4:] is the number)
        last_cse_sym = max(used_cse_syms, key=lambda sym: int(sym.name[4:]))

        if result_expr not in last_cse_expr_used_in_result_expr[last_cse_sym]:
            last_cse_expr_used_in_result_expr[last_cse_sym].append(result_expr)

    debug(
        f"Found {len(result_expr_without_cse_syms)} result expressions "
        "without any cse_syms.",
    )

    atoms = [state.sym for state in self.root.full_states]
    atoms.extend(param.sym for param in self.root.parameters)

    def _add_result_expression(result_expr):
        # Register every (name, indices) pair computed from result_expr as
        # an IndexedExpression, with cse_syms replaced by their substitutes
        exp_expr = result_expr.xreplace(cse_subs)
        for result_name, indices in result_expr_map[result_expr]:
            sym = self.add_indexed_expression(
                result_name,
                indices,
                exp_expr,
                add_offset=state_offset,
            )
            result_obj = self.ode_objects.get(sympycode(sym))
            _collect_dependencies(result_obj)

            # Register the new result expression
            new_results[result_name].append(result_obj)
            body_expressions.append(result_obj)

    self.add_comment(
        "Common sub expressions for the body and the result expressions",
    )
    body_expressions.append(self.ode_objects[-1])

    cse_cnt = 0

    # Register the common sub expressions as Intermediates
    for cse_sym, expr in list(cse_exprs.items()):

        # If the expression is just one of the atoms of the ODE we
        # skip the cse expression but add a subs for the atom. We
        # also skip Relationals and Piecewise as the type checking
        # in Piecewise otherwise kicks in and destroys things for us.
        if expr in atoms or isinstance(
            expr,
            (sp.Piecewise, sp.relational.Relational, sp.relational.Boolean),
        ):
            cse_subs[cse_sym] = expr.xreplace(cse_subs)
        else:
            # Add body expression as an intermediate expression
            sym = self.add_intermediate(f"cse_{cse_cnt}", expr.xreplace(cse_subs))
            obj = self.ode_objects.get(sympycode(sym))
            _collect_dependencies(obj)
            cse_subs[cse_sym] = sym
            cse_cnt += 1
            body_expressions.append(obj)

        # Add the result expressions whose last used cse_sym is this one
        for result_expr in last_cse_expr_used_in_result_expr.pop(cse_sym, []):
            _add_result_expression(result_expr)

    # Result expressions that use no cse_sym are not attached to any
    # cse_sym above; register them here so they are not lost
    for result_expr in result_expr_without_cse_syms:
        _add_result_expression(result_expr)

    if might_take_time:
        info(" done")

    # Sort used state, parameters and expr
    self.used_states = sorted(used_states)
    self.used_parameters = sorted(used_parameters)

    return new_results, body_expressions
def test_creation():
    """Build a small multi-component ODE and check its bookkeeping.

    Adds states, parameters, intermediates, derivatives, algebraic
    expressions, a solved state and Markov model rates. Checks counts and
    expected GotranException cases along the way. Finally round-trips
    the ODE through ``ode.save``/``gotran.load_ode``.
    """
    # Adding a phoney ODE
    ode = gotran.ODE("test")

    # Add states and parameters
    j = ode.add_state("j", 1.0)
    i = ode.add_state("i", 2.0)
    k = ode.add_state("k", 3.0)

    ii = ode.add_parameter("ii", 0.0)
    jj = ode.add_parameter("jj", 0.0)
    kk = ode.add_parameter("kk", 0.0)

    # Try overwriting state
    with pytest.raises(gotran.GotranException):
        ode.add_parameter("j", 1.0)

    # Try overwriting parameter
    with pytest.raises(gotran.GotranException):
        ode.add_state("ii", 1.0)

    assert ode.num_states == 3
    assert ode.num_parameters == 3
    assert ode.present_component == ode

    # Add an Expression
    ode.alpha = i * j

    # Add derivatives for all states in the main component
    ode.add_comment("Some nice derivatives and an algebraic expression")
    ode.di_dt = ode.alpha + ii
    ode.dj_dt = -ode.alpha - jj
    ode.alg_k_0 = kk * k * ode.alpha

    assert ode.num_intermediates == 1

    # Add a component with 4 states and 2 parameters
    ode("jada").add_states(m=2.0, n=3.0, l=1.0, o=4.0)
    ode("jada").add_parameters(ll=1.0, mm=2.0)

    # Define a state derivative
    ode("jada").dm_dt = ode("jada").ll - (ode("jada").m - ode.i)

    jada = ode("jada")
    assert ode.present_component == jada

    # Test num_foo
    assert jada.num_states == 4
    assert jada.num_parameters == 2
    assert ode.num_states == 7
    assert ode.num_parameters == 5
    assert ode.num_components == 2
    assert jada.num_components == 1

    # Add expressions to the component
    jada.tmp = jada.ll * jada.m**2 + 3 / i - ii * jj
    jada.tmp2 = ode.j * exp(jada.tmp)

    # Reduce state n
    jada.add_solve_state(jada.n, 1 - jada.l - jada.m - jada.n)

    assert ode.num_intermediates == 4

    # Try overwriting parameter with expression
    with pytest.raises(gotran.GotranException):
        jada.ll = jada.tmp * jada.tmp2

    # Create a derivative expression
    ode.add_comment("More funky objects")
    jada.tmp3 = jada.tmp2.diff(ode.t) + jada.n + jada.o
    jada.add_derivative(jada.l, ode.t, jada.tmp3)
    jada.add_algebraic(jada.o, jada.o**2 - exp(jada.o) + 2 / jada.o)

    assert ode.num_intermediates == 9
    assert ode.num_state_expressions == 6
    assert ode.is_complete
    assert ode.num_full_states == 6

    # Try adding expressions to ode component
    with pytest.raises(gotran.GotranException):
        ode.p = 1.0

    # Check used in and dependencies for one intermediate
    tmp3 = ode.present_ode_objects["tmp3"]
    assert ode.object_used_in[tmp3] == {ode.present_ode_objects["dl_dt"]}
    for sym in symbols_from_expr(tmp3.expr, include_derivatives=True):
        assert (
            ode.present_ode_objects[sympycode(sym)]
            in ode.expression_dependencies[tmp3]
        )

    # Add another component to test rates
    bada = ode("bada")
    bada.add_parameters(nn=5.0, oo=3.0, qq=1.0, pp=2.0)

    nada = bada.add_component("nada")
    nada.add_states(("r", 3.0), ("s", 4.0), ("q", 1.0), ("p", 2.0))
    assert bada.num_parameters == 4
    assert bada.num_states == 4
    nada.p = 1 - nada.r - nada.s - nada.q

    assert "".join(p.name for p in ode.parameters) == "iijjkkllmmnnooppqq"
    assert "".join(s.name for s in ode.states) == "jiklmnorsqp"

    assert not ode.is_complete

    # Add rates to component making it a Markov model component
    nada.rates[nada.r, nada.s] = 3 * exp(-i)

    # Try add a state derivative to Markov model
    with pytest.raises(gotran.GotranException):
        nada.ds_dt = 3.0

    nada.rates[nada.s, nada.r] = 2.0
    nada.rates[nada.s, nada.q] = 2.0
    nada.rates[nada.q, nada.s] = 2 * exp(-i)
    nada.rates[nada.q, nada.p] = 3.0
    nada.rates[nada.p, nada.q] = 4.0

    assert ode.present_component == nada

    markov = bada.add_component("markov_2")
    markov.add_states(("tt", 3.0), ("u", 4.0), ("v", 1.0))

    with pytest.raises(gotran.GotranException):
        markov.rates[nada.s, nada.r] = 2.0

    with pytest.raises(gotran.GotranException):
        markov.rates[markov.tt, markov.u] = 2 * exp(markov.u)

    with pytest.raises(gotran.GotranException):
        markov.rates[[markov.tt, markov.u, markov.v]] = 5.0

    with pytest.raises(gotran.GotranException):
        markov.rates[[markov.tt, markov.u, markov.v]] = Matrix(
            [[1, 2 * i, 0.0], [0.0, 2.0, 4.0]],
        )

    with pytest.raises(gotran.GotranException):
        markov.rates[markov.tt, markov.tt] = 5.0

    markov.rates[[markov.tt, markov.u, markov.v]] = Matrix(
        [[0.0, 2 * i, 2.0], [4.0, 0.0, 2.0], [5.0, 2.0, 0.0]],
    )

    ode.finalize()
    assert ode.is_complete
    assert ode.is_dae

    # Test Mass matrix
    vector = ode.mass_matrix * Matrix([1] * ode.num_full_states)
    assert (0, 0) == (vector[2], vector[5])
    assert sum(ode.mass_matrix) == ode.num_full_states - 2
    assert sum(vector) == ode.num_full_states - 2

    assert "".join(s.name for s in ode.full_states) == "ijkmlorsqttuv"
    assert ode.present_component == ode

    # Test saving
    ode.save("test_ode")

    try:
        # Test loading
        ode_loaded = gotran.load_ode("test_ode")
    finally:
        # Remove the saved file even when loading fails
        os.unlink("test_ode.ode")

    # NOTE(review): a check that ode.signature() == ode_loaded.signature()
    # was disabled here; re-enable once confirmed it holds after loading.

    # Check that all objects are the same and evaluate to the same value
    for name, obj in list(ode.present_ode_objects.items()):
        loaded_obj = ode_loaded.present_ode_objects[name]
        assert type(obj) is type(loaded_obj)
        assert loaded_obj.param.value == pytest.approx(obj.param.value)