def test_to_python(self):
    """Check the code ``pybamm.to_python`` generates for the sum of two StateVectors."""
    a = pybamm.StateVector(slice(0, 1))
    b = pybamm.StateVector(slice(1, 2))

    # test a + b
    expr = a + b
    _, variable_str = pybamm.to_python(expr)

    # raw strings: the patterns contain regex escapes (\[ and \+) that are
    # invalid escape sequences in ordinary string literals; "\n" inside a raw
    # string is passed to the regex engine, which matches a newline
    expected_str = (
        r"var_[0-9m]+ = y\[0:1\].*\n"
        r"var_[0-9m]+ = y\[1:2\].*\n"
        r"var_[0-9m]+ = var_[0-9m]+ \+ var_[0-9m]+"
    )
    self.assertRegex(variable_str, expected_str)
def test_to_python(self):
    """Check the ``self.``-prefixed code ``pybamm.to_python`` generates for ``a + b``."""
    a = pybamm.StateVector(slice(0, 1))
    b = pybamm.StateVector(slice(1, 2))

    # test a + b
    expr = a + b
    _, variable_str = pybamm.to_python(expr)

    # raw strings: the patterns contain regex escapes (\. \[ \] \+) that are
    # invalid escape sequences in ordinary string literals; "\n" inside a raw
    # string is passed to the regex engine, which matches a newline
    expected_str = (
        r"self\.var_[0-9m]+ = y\[:1\]\[\[True\]\].*\n"
        r"self\.var_[0-9m]+ = y\[:2\]\[\[False, True\]\].*\n"
        r"self\.var_[0-9m]+ = self\.var_[0-9m]+ \+ self\.var_[0-9m]+"
    )
    self.assertRegex(variable_str, expected_str)
def __init__(self, symbol):
    """Generate, compile and bind a python ``evaluate`` function for ``symbol``.

    The code returned by ``pybamm.to_python`` is wrapped in a function
    ``evaluate(constants, t=None, y=None, y_dot=None, inputs=None,
    known_evals=None)``. The function is compiled and executed, and the
    generated source stores it on this object as ``self._evaluate``.

    Parameters
    ----------
    symbol : pybamm.Symbol
        The expression tree to convert to python code.
    """
    constants, python_str = pybamm.to_python(symbol, debug=False)

    # the generated code refers to each constant by variable name; unpack them
    # from the `constants` argument of the generated function. The assignments
    # are independent, and are emitted in reverse key order (last key first).
    constant_lines = [
        "{} = constants[{}]\n".format(id_to_python_variable(symbol_id, True), i)
        for i, symbol_id in enumerate(constants.keys())
    ]
    python_str = "".join(reversed(constant_lines)) + python_str

    # constants passed in as an ordered dict, convert to list (index i of this
    # list matches `constants[i]` in the generated code)
    self._constants = list(constants.values())

    # indent code so it becomes the body of `evaluate`
    python_str = " " + python_str
    python_str = python_str.replace("\n", "\n ")

    # add function def to first line
    python_str = (
        "def evaluate(constants, t=None, y=None, "
        "y_dot=None, inputs=None, known_evals=None):\n" + python_str
    )

    # calculate the final variable that will output the result of calling
    # `evaluate` on `symbol`
    is_constant = symbol.is_constant()
    result_var = id_to_python_variable(symbol.id, is_constant)

    # a constant symbol that evaluates to a plain number is returned as a literal
    result_value = symbol.evaluate() if is_constant else None
    if isinstance(result_value, numbers.Number):
        python_str = python_str + "\n return " + str(result_value)
    else:
        python_str = python_str + "\n return " + result_var

    # when executed, the generated code binds `evaluate` to this object; it
    # refers to the local name `self` of this method
    python_str = python_str + "\nself._evaluate = evaluate"

    self._python_str = python_str
    self._result_var = result_var
    self._symbol = symbol

    # compile and run the generated python code; running it defines `evaluate`
    # and stores it as `self._evaluate`
    compiled_function = compile(python_str, result_var, "exec")
    exec(compiled_function)
def __init__(self, symbol):
    """Convert ``symbol`` to python code and pre-compile it.

    Every constant found in the expression tree becomes an attribute of this
    object. Its attribute name is the generated constant variable name with
    ``"self."`` removed. Two code objects are stored:

    * ``self._variable_compiled`` -- the generated statements (``exec`` mode);
    * ``self._return_compiled`` -- the expression naming the result variable
      (``eval`` mode).

    Parameters
    ----------
    symbol : pybamm.Symbol
        The expression tree to convert to python code.
    """
    constants, variable_code = pybamm.to_python(symbol, debug=False)
    self._variable_function = variable_code

    # expose each constant as an attribute so the generated code can read it
    for const_id in constants:
        attr_name = id_to_python_variable(const_id, True).replace("self.", "")
        setattr(self, attr_name, constants[const_id])

    # name of the generated variable holding the value of `symbol`
    result_var = id_to_python_variable(symbol.id, symbol.is_constant())
    self._result_var = result_var

    # statements computing all intermediate variables
    self._variable_compiled = compile(variable_code, result_var, "exec")
    # expression producing the final output of `evaluate`
    self._return_compiled = compile(result_var, "return" + result_var, "eval")
def __init__(self, symbol):
    """Generate a JAX ``evaluate_jax`` function for ``symbol`` and jit it.

    The code returned by ``pybamm.to_python(..., output_jax=True)`` has its
    ``np.`` calls rewritten to ``jax.numpy.``. It is then wrapped in a function
    whose positional arguments are the tree's constants followed by
    ``t, y, y_dot, inputs, known_evals``, and compiled. These attributes are
    stored:

    * ``self._evaluate_jax`` -- the plain generated function;
    * ``self._jit_evaluate`` -- its ``jax.jit`` version;
    * ``self._jac_evaluate`` -- a jitted forward-mode jacobian with respect
      to ``y``.

    Parameters
    ----------
    symbol : pybamm.Symbol
        The expression tree to convert to JAX code.
    """
    constants, python_str = pybamm.to_python(symbol, debug=False, output_jax=True)

    # replace numpy function calls to jax numpy calls
    python_str = python_str.replace("np.", "jax.numpy.")

    # convert all numpy constants to device vectors (existing keys are
    # reassigned only, so the dict size does not change)
    for symbol_id in constants:
        if isinstance(constants[symbol_id], np.ndarray):
            constants[symbol_id] = jax.device_put(constants[symbol_id])

    # names of the constant arguments; they come first in the signature of
    # the generated function, in the same order as `constants`
    arg_list = [id_to_python_variable(symbol_id, True) for symbol_id in constants]

    # positions of the constants that can be marked static: static arguments
    # must be hashable, and a jax device array is not
    static_argnums = tuple(
        i
        for i, c in enumerate(constants.values())
        if not isinstance(c, jax.interpreters.xla.DeviceArray)
    )

    # store constants
    self._constants = tuple(constants.values())

    # indent code so it becomes the body of `evaluate_jax`
    python_str = " " + python_str
    python_str = python_str.replace("\n", "\n ")

    # add function def to first line
    args = "t=None, y=None, y_dot=None, inputs=None, known_evals=None"
    if arg_list:
        args = ",".join(arg_list) + ", " + args
    python_str = "def evaluate_jax({}):\n".format(args) + python_str

    # calculate the final variable that will output the result of calling
    # `evaluate` on `symbol`
    is_constant = symbol.is_constant()
    result_var = id_to_python_variable(symbol.id, is_constant)

    # a constant symbol that evaluates to a plain number is returned as a literal
    result_value = symbol.evaluate() if is_constant else None
    if isinstance(result_value, numbers.Number):
        python_str = python_str + "\n return " + str(result_value)
    else:
        python_str = python_str + "\n return " + result_var

    # store a copy of the un-jitted function (e.g. for examining its jaxpr);
    # the generated code refers to the local name `self` of this method
    python_str = python_str + "\nself._evaluate_jax = evaluate_jax"

    # store the final generated code
    self._python_str = python_str

    # compile and run the generated python code; running it defines
    # `evaluate_jax` and stores it as `self._evaluate_jax`
    compiled_function = compile(python_str, result_var, "exec")
    exec(compiled_function)

    n = len(arg_list)
    self._jit_evaluate = jax.jit(self._evaluate_jax, static_argnums=static_argnums)

    # store a jit version of evaluate_jax's jacobian; the positional arguments
    # are (constant_0, ..., constant_{n-1}, t, y, ...), so y is argument n + 1
    jacobian_evaluate = jax.jacfwd(self._evaluate_jax, argnums=1 + n)
    self._jac_evaluate = jax.jit(jacobian_evaluate, static_argnums=static_argnums)