Ejemplo n.º 1
0
    def get_integer(
        self, tree: Union[ast.Primary, ast.ComponentRef, ast.Expression,
                          ast.Slice]
    ) -> Union[int, ca.MX, np.ndarray, slice, None]:
        """
        Evaluate an AST node that is expected to denote an integer.

        CasADi needs to know the dimensions of symbols at instantiation, so
        expressions that define dimensions of symbols have to be evaluated
        eagerly.

        :param tree: Primary, ComponentRef, Expression or Slice node.
        :return: ``None`` for a Primary without a value; an ``int`` when the
            value could be determined; the raw (non-constant) first output of
            the evaluated CasADi function when the expression could not be
            reduced to a constant; a ``slice`` for Slice nodes.
        :raises AssertionError: if a ComponentRef refers to a symbol whose
            type is not ``Integer``.
        :raises Exception: for any other node type.
        """
        if isinstance(tree, ast.Primary):
            return None if tree.value is None else int(tree.value)

        elif isinstance(tree, ast.ComponentRef):
            s = self.current_class.symbols[tree.name]
            # Explicit check rather than ``assert``: assertions are stripped
            # under ``python -O``, which would silently accept non-Integer
            # symbols here.
            if s.type.name != 'Integer':
                raise AssertionError(
                    "Symbol '{}' has type '{}', expected 'Integer'".format(
                        tree.name, s.type.name))
            return self.get_integer(s.value)

        elif isinstance(tree, ast.Expression):
            # Make sure that the expression has been converted to MX by
            # (re)visiting the relevant part of the AST.
            ast_walker = TreeWalker()
            ast_walker.walk(self, tree)

            # Obtain expression and the symbols it depends on
            expr = self.get_mx(tree)
            free_vars = ca.symvar(expr)

            # Find the values of the symbols. A symbol named like the
            # innermost for-loop is substituted with that loop's
            # index_variable; any other symbol is resolved recursively from
            # its declared value in the current class.
            vals = []
            for free_var in free_vars:
                if free_var.is_symbolic():
                    if (len(self.for_loops) > 0) and (
                            free_var.name() == self.for_loops[-1].name):
                        vals.append(self.for_loops[-1].index_variable)
                    else:
                        vals.append(
                            self.get_integer(self.current_class.symbols[
                                free_var.name()].value))

            # Evaluate the expression
            F = ca.Function('get_integer', free_vars, [expr])
            ret = F.call(vals, *self.function_mode)
            if ret[0].is_constant():
                # We managed to evaluate the expression. Assume the result
                # to be integer.
                return int(ret[0])
            else:
                # Expression depends on other symbols. Could not extract
                # integer value.
                return ret[0]

        elif isinstance(tree, ast.Slice):
            start = self.get_integer(tree.start)
            step = self.get_integer(tree.step)
            stop = self.get_integer(tree.stop)
            return slice(start, stop, step)

        else:
            raise Exception('Unexpected node type {}'.format(
                tree.__class__.__name__))
Ejemplo n.º 2
0
def generate(ast_tree: ast.Tree, model_name: str):
    """
    Flatten the requested model class and emit SymPy source code for it.

    Flattening operates on a deep copy of ``ast_tree``; the caller's tree
    itself is never handed to ``flatten``.

    :param ast_tree: AST to generate from
    :param model_name: class to generate
    :return: sympy source code for model
    """
    ref = ast.ComponentRef.from_string(model_name)
    tree_copy = copy.deepcopy(ast_tree)
    walker = TreeWalker()
    flat = flatten(tree_copy, ref)
    generator = SympyGenerator()
    walker.walk(generator, flat)
    # The generator stores the emitted source keyed by the flattened node.
    return generator.src[flat]
Ejemplo n.º 3
0
def generate(ast_tree: ast.Tree, model_name: str) -> str:
    """
    Flatten the requested model class and render it as an XML document.

    Flattening operates on a deep copy of ``ast_tree``; the caller's tree
    itself is never handed to ``flatten``.

    :param ast_tree: AST to generate from
    :param model_name: class to generate
    :return: pretty-printed XML for the model, decoded as a UTF-8 string
    """
    component_ref = ast.ComponentRef.from_string(model_name)
    ast_tree_new = copy.deepcopy(ast_tree)
    ast_walker = TreeWalker()
    flat_tree = flatten(ast_tree_new, component_ref)
    gen = XmlGenerator()
    ast_walker.walk(gen, flat_tree)
    # etree.tostring returns bytes; decode so callers receive a str.
    return etree.tostring(gen.xml[flat_tree], pretty_print=True).decode('utf-8')
Ejemplo n.º 4
0
def generate(ast_tree: ast.Tree, model_name: str) -> str:
    """
    Flatten the requested model class and render it as an XML document.

    Flattening operates on a deep copy of ``ast_tree``; the caller's tree
    itself is never handed to ``flatten``.

    :param ast_tree: AST to generate from
    :param model_name: class to generate
    :return: pretty-printed XML for the model, decoded as a UTF-8 string
    """
    component_ref = ast.ComponentRef.from_string(model_name)
    ast_tree_new = copy.deepcopy(ast_tree)
    ast_walker = TreeWalker()
    flat_tree = flatten(ast_tree_new, component_ref)
    gen = XmlGenerator()
    ast_walker.walk(gen, flat_tree)
    # etree.tostring returns bytes; decode so callers receive a str.
    return etree.tostring(gen.xml[flat_tree], pretty_print=True).decode('utf-8')
Ejemplo n.º 5
0
def generate(ast_tree: ast.Tree, model_name: str):
    """
    Produce SymPy source code for one model class of an AST.

    The tree is deep-copied first and the copy is what gets flattened.

    :param ast_tree: AST to generate from
    :param model_name: class to generate
    :return: sympy source code for model
    """
    target = ast.ComponentRef.from_string(model_name)
    working_tree = copy.deepcopy(ast_tree)
    tree_walker = TreeWalker()
    flattened = flatten(working_tree, target)
    emitter = SympyGenerator()
    tree_walker.walk(emitter, flattened)
    source_by_node = emitter.src
    return source_by_node[flattened]
Ejemplo n.º 6
0
def generate(ast_tree: ast.Tree,
             model_name: str,
             options: Dict[str, Union[bool, str]] = None) -> str:
    """
    Produce SymPy source code for one model class of an AST.

    The tree is deep-copied first and the copy is what gets flattened.

    :param ast_tree: AST to generate from
    :param model_name: class to generate
    :param options: accepted for interface compatibility; currently ignored
    :return: sympy source code for model
    """
    del options  # Unused; kept so all generators share one signature shape
    target = ast.ComponentRef.from_string(model_name)
    working_tree = copy.deepcopy(ast_tree)
    tree_walker = TreeWalker()
    flattened = flatten(working_tree, target)
    emitter = SympyGenerator()
    tree_walker.walk(emitter, flattened)
    return emitter.src[flattened]
Ejemplo n.º 7
0
    def get_integer(self, tree: Union[ast.Primary, ast.ComponentRef, ast.Expression, ast.Slice]) -> Union[int, ca.MX, np.ndarray, slice, None]:
        """
        Evaluate an AST node that is expected to denote an integer.

        CasADi needs to know the dimensions of symbols at instantiation, so expressions
        that define dimensions of symbols have to be evaluated eagerly.

        :param tree: Primary, ComponentRef, Expression or Slice node.
        :return: ``None`` for a Primary without a value; an ``int`` when the value could be
            determined; the raw (non-constant) first output of the evaluated CasADi function
            when the expression could not be reduced to a constant; a ``slice`` for Slice nodes.
        :raises AssertionError: if a ComponentRef refers to a symbol whose type is not ``Integer``.
        :raises Exception: for any other node type.
        """
        if isinstance(tree, ast.Primary):
            return None if tree.value is None else int(tree.value)

        elif isinstance(tree, ast.ComponentRef):
            s = self.current_class.symbols[tree.name]
            # Explicit check rather than ``assert``: assertions are stripped under
            # ``python -O``, which would silently accept non-Integer symbols here.
            if s.type.name != 'Integer':
                raise AssertionError("Symbol '{}' has type '{}', expected 'Integer'".format(tree.name, s.type.name))
            return self.get_integer(s.value)

        elif isinstance(tree, ast.Expression):
            # Make sure that the expression has been converted to MX by (re)visiting the
            # relevant part of the AST.
            ast_walker = TreeWalker()
            ast_walker.walk(self, tree)

            # Obtain expression and the symbols it depends on
            expr = self.get_mx(tree)
            free_vars = ca.symvar(expr)

            # Find the values of the symbols. A symbol named like the innermost for-loop is
            # substituted with that loop's index_variable; any other symbol is resolved
            # recursively from its declared value in the current class.
            vals = []
            for free_var in free_vars:
                if free_var.is_symbolic():
                    if (len(self.for_loops) > 0) and (free_var.name() == self.for_loops[-1].name):
                        vals.append(self.for_loops[-1].index_variable)
                    else:
                        vals.append(self.get_integer(self.current_class.symbols[free_var.name()].value))

            # Evaluate the expression
            F = ca.Function('get_integer', free_vars, [expr])
            ret = F.call(vals, *self.function_mode)
            if ret[0].is_constant():
                # We managed to evaluate the expression.  Assume the result to be integer.
                return int(ret[0])
            else:
                # Expression depends on other symbols.  Could not extract integer value.
                return ret[0]

        elif isinstance(tree, ast.Slice):
            start = self.get_integer(tree.start)
            step = self.get_integer(tree.step)
            stop = self.get_integer(tree.stop)
            return slice(start, stop, step)

        else:
            raise Exception('Unexpected node type {}'.format(tree.__class__.__name__))
Ejemplo n.º 8
0
def generate(ast_tree: ast.Tree,
             model_name: str,
             options: Dict[str, bool] = None) -> Model:
    """
    Flatten the requested model class and build a CasADi model from it.

    NOTE(review): ``ast_tree`` is passed to ``flatten`` directly, without a
    copy; confirm that ``flatten`` does not mutate its input.

    :param ast_tree: AST to generate from
    :param model_name: class to generate
    :param options: dictionary of generator options (``None`` means empty)
    :return: casadi model
    """
    opts = {} if options is None else options
    ref = ast.ComponentRef.from_string(model_name)
    walker = TreeWalker()
    flat_tree = flatten(ast_tree, ref)
    # The generator is given only the last component of the (possibly
    # dotted) model name.
    short_name = ref.to_tuple()[-1]
    casadi_gen = Generator(flat_tree, short_name, opts)
    walker.walk(casadi_gen, flat_tree)
    return casadi_gen.model