Esempio n. 1
0
    def test_to_python(self):
        """Check the code ``pybamm.to_python`` generates for the sum of two state vectors.

        Each state-vector slice should be assigned to a variable, followed by a
        line that adds those two variables.
        """
        a = pybamm.StateVector(slice(0, 1))
        b = pybamm.StateVector(slice(1, 2))

        # test a + b
        expr = a + b
        _, variable_str = pybamm.to_python(expr)
        # Raw strings: the patterns contain regex escapes (`\[`, `\+`) that are
        # invalid escape sequences in ordinary Python string literals.
        expected_str = (
            r"var_[0-9m]+ = y\[0:1\].*\n"
            r"var_[0-9m]+ = y\[1:2\].*\n"
            r"var_[0-9m]+ = var_[0-9m]+ \+ var_[0-9m]+"
        )

        self.assertRegex(variable_str, expected_str)
Esempio n. 2
0
    def test_to_python(self):
        """Check the code ``pybamm.to_python`` generates for the sum of two state vectors.

        Each state vector should be assigned to a ``self.var_*`` attribute via a
        boolean mask over ``y``, followed by a line that adds those two
        attributes.
        """
        a = pybamm.StateVector(slice(0, 1))
        b = pybamm.StateVector(slice(1, 2))

        # test a + b
        expr = a + b
        _, variable_str = pybamm.to_python(expr)
        # Raw strings: the patterns contain regex escapes (`\.`, `\[`, `\+`)
        # that are invalid escape sequences in ordinary Python string literals.
        expected_str = (
            r"self\.var_[0-9m]+ = y\[:1\]\[\[True\]\].*\n"
            r"self\.var_[0-9m]+ = y\[:2\]\[\[False, True\]\].*\n"
            r"self\.var_[0-9m]+ = self\.var_[0-9m]+ \+ self\.var_[0-9m]+"
        )

        self.assertRegex(variable_str, expected_str)
Esempio n. 3
0
    def __init__(self, symbol):
        """Generate, compile and bind a plain-Python evaluator for ``symbol``.

        The generated source, stored in ``self._python_str``, defines
        ``evaluate(constants, t=None, y=None, y_dot=None, inputs=None,
        known_evals=None)``. Executing it binds that function to
        ``self._evaluate``. The constant values extracted by
        ``pybamm.to_python`` are stored in ``self._constants``, in the order
        the generated code indexes them.

        Parameters
        ----------
        symbol : pybamm.Symbol
            The expression tree to convert.
        """
        constants, python_str = pybamm.to_python(symbol, debug=False)

        # Unpack the constants at the top of the generated function, one line
        # per constant. The lines are emitted last-constant-first (the layout
        # this class has always produced). Each line is independent, so the
        # order does not affect evaluation.
        const_lines = [
            "{} = constants[{}]\n".format(id_to_python_variable(symbol_id, True), i)
            for i, symbol_id in enumerate(constants)
        ]
        python_str = "".join(reversed(const_lines)) + python_str

        # constants are passed in as an ordered dict; the generated code reads
        # them positionally, so store them as a list in the same order
        self._constants = list(constants.values())

        # indent the body of the generated function (3 spaces)
        python_str = "   " + python_str
        python_str = python_str.replace("\n", "\n   ")

        # add function def to first line
        python_str = (
            "def evaluate(constants, t=None, y=None, "
            "y_dot=None, inputs=None, known_evals=None):\n" + python_str
        )

        # the variable that holds the result of calling `evaluate` on `symbol`
        is_constant = symbol.is_constant()
        result_var = id_to_python_variable(symbol.id, is_constant)

        # a constant scalar result is returned as a literal; otherwise return
        # the result variable
        return_expr = result_var
        if is_constant:
            result_value = symbol.evaluate()
            if isinstance(result_value, numbers.Number):
                return_expr = str(result_value)
        python_str = python_str + "\n   return " + return_expr

        # bind the generated function to this object when the code is executed
        python_str = python_str + "\nself._evaluate = evaluate"

        self._python_str = python_str
        self._result_var = result_var
        self._symbol = symbol

        # Compile and run the generated code. `exec` without explicit
        # namespaces runs in the current scope, which is how the generated
        # `self._evaluate = evaluate` line can see `self`.
        compiled_function = compile(python_str, result_var, "exec")
        exec(compiled_function)
Esempio n. 4
0
    def __init__(self, symbol):
        """Compile the Python code generated for ``symbol`` by ``pybamm.to_python``.

        Every constant returned alongside the code is set as an attribute of
        this object. The attribute name is the constant's generated variable
        name with ``"self."`` removed, presumably so the generated code can
        refer to it as ``self.<name>``.

        Attributes set:

        - ``self._variable_function``: the generated source.
        - ``self._variable_compiled``: that source compiled in ``exec`` mode.
        - ``self._result_var``: the variable name holding the result.
        - ``self._return_compiled``: ``self._result_var`` compiled as an
          ``eval`` expression.
        """
        constants, code = pybamm.to_python(symbol, debug=False)
        self._variable_function = code

        # attach each constant to this object under its generated name
        for const_id, const_value in constants.items():
            attr_name = id_to_python_variable(const_id, True).replace("self.", "")
            setattr(self, attr_name, const_value)

        # name of the variable the generated code assigns the final result to
        result_var = id_to_python_variable(symbol.id, symbol.is_constant())
        self._result_var = result_var

        # statements of the generated code, compiled once for repeated execution
        self._variable_compiled = compile(code, result_var, "exec")

        # expression that just reads back the result variable
        self._return_compiled = compile(result_var, "return" + result_var, "eval")
Esempio n. 5
0
    def __init__(self, symbol):
        """Generate a JAX evaluator for ``symbol`` and jit-compile it.

        The code from ``pybamm.to_python`` has its ``np.`` calls redirected to
        ``jax.numpy.``. It is then wrapped in ``evaluate_jax(<constants...>,
        t=None, y=None, y_dot=None, inputs=None, known_evals=None)``, executed,
        and bound to ``self._evaluate_jax``.

        Two jitted functions are also stored:

        - ``self._jit_evaluate``: a jitted ``evaluate_jax``.
        - ``self._jac_evaluate``: a jitted forward-mode Jacobian of
          ``evaluate_jax`` with respect to ``y``.

        Parameters
        ----------
        symbol : pybamm.Symbol
            The expression tree to convert.
        """
        constants, python_str = pybamm.to_python(symbol, debug=False, output_jax=True)

        # replace numpy function calls with jax numpy calls
        python_str = python_str.replace("np.", "jax.numpy.")

        # convert all numpy constants to device arrays
        for symbol_id in constants:
            if isinstance(constants[symbol_id], np.ndarray):
                constants[symbol_id] = jax.device_put(constants[symbol_id])

        # names of the constant arguments, in the order they lead the signature
        # of the generated function
        arg_list = [id_to_python_variable(symbol_id, True) for symbol_id in constants]

        # Constants that are not device arrays are hashable and are passed to
        # jax.jit as static arguments; device arrays are not hashable.
        # `jax.interpreters.xla.DeviceArray` only exists on older JAX releases.
        # Newer ones return `jax.Array` from `device_put`, so accept whichever
        # of the two this JAX version provides.
        xla_module = getattr(jax.interpreters, "xla", None)
        device_array_types = tuple(
            array_type
            for array_type in (
                getattr(xla_module, "DeviceArray", None),
                getattr(jax, "Array", None),
            )
            if array_type is not None
        )
        static_argnums = tuple(
            i
            for i, value in enumerate(constants.values())
            if not isinstance(value, device_array_types)
        )

        # store constants
        self._constants = tuple(constants.values())

        # indent the body of the generated function (3 spaces)
        python_str = "   " + python_str
        python_str = python_str.replace("\n", "\n   ")

        # signature: the constants first, then the standard keyword arguments
        args = "t=None, y=None, y_dot=None, inputs=None, known_evals=None"
        if arg_list:
            args = ",".join(arg_list) + ", " + args
        python_str = "def evaluate_jax({}):\n".format(args) + python_str

        # the variable that holds the result of calling `evaluate_jax` on `symbol`
        is_constant = symbol.is_constant()
        result_var = id_to_python_variable(symbol.id, is_constant)

        # a constant scalar result is returned as a literal; otherwise return
        # the result variable
        return_expr = result_var
        if is_constant:
            result_value = symbol.evaluate()
            if isinstance(result_value, numbers.Number):
                return_expr = str(result_value)
        python_str = python_str + "\n   return " + return_expr

        # Keep an un-jitted copy of the generated function on this object,
        # presumably so its jaxpr can be examined.
        python_str = python_str + "\nself._evaluate_jax = evaluate_jax"

        # store the final generated code
        self._python_str = python_str

        # Compile and run the generated code. `exec` without explicit
        # namespaces runs in the current scope, so the generated code can
        # assign to `self`.
        compiled_function = compile(python_str, result_var, "exec")
        exec(compiled_function)

        # the n constants occupy positions 0..n-1, so `t` is at n and `y` at n + 1
        n = len(arg_list)
        self._jit_evaluate = jax.jit(self._evaluate_jax, static_argnums=static_argnums)

        # store a jit version of evaluate_jax's jacobian with respect to `y`
        jacobian_evaluate = jax.jacfwd(self._evaluate_jax, argnums=1 + n)
        self._jac_evaluate = jax.jit(jacobian_evaluate, static_argnums=static_argnums)