예제 #1
0
    def get_derivative(self, s):
        """Return the time derivative ``der(s)`` of a CasADi expression.

        Four cases are handled, in order:

        1. ``s`` is a constant: the derivative is ``0``.
        2. ``s`` is a bare symbol: a ``der(<name>)`` symbol of the same size
           is created on first use, cached in ``self.derivative`` and
           registered in ``self.nodes[self.current_class]``. Inside a for
           loop, symbols listed in the innermost loop's ``indexed_symbols``
           are instead mapped to an indexed element of the derivative of the
           un-indexed symbol (or ``0`` if that derivative is not symbolic).
        3. ``s`` is a slice (``OP_GETNONZEROS``) of a symbol, e.g. ``x[1]``:
           the same slice of the derivative of the underlying symbol is
           returned, creating and registering that derivative if needed.
        4. Otherwise ``s`` is an expression and the chain rule is applied:
           ``der(s) = J(s, deps) * der(deps)``, where ``deps`` are the free
           symbols of ``s``.

        :param s: CasADi expression (MX) or constant to differentiate.
        :returns: ``0`` or a CasADi expression for ``der(s)``.
        """

        # Case 1: s is a constant, e.g. MX(5)
        if ca.MX(s).is_constant():
            return 0

        # Case 2: s is a symbol, e.g. MX(x)
        elif s.is_symbolic():
            if s.name() not in self.derivative:
                if len(self.for_loops
                       ) > 0 and s in self.for_loops[-1].indexed_symbols:
                    # Create a new indexed symbol, referencing to the for loop index inside the vector derivative symbol.
                    for_loop_symbol = self.for_loops[-1].indexed_symbols[s]
                    s_without_index = self.get_mx(
                        ast.ComponentRef(name=for_loop_symbol.tree.name))
                    der_s_without_index = self.get_derivative(s_without_index)
                    if ca.MX(der_s_without_index).is_symbolic():
                        return self.get_indexed_symbol(
                            ast.ComponentRef(
                                name=der_s_without_index.name(),
                                indices=for_loop_symbol.tree.indices),
                            der_s_without_index)
                    else:
                        return 0
                else:
                    der_s = ca.MX.sym("der({})".format(s.name()), s.size())
                    self.derivative[s.name()] = der_s
                    self.nodes[self.current_class][der_s.name()] = der_s
                    return der_s
            else:
                return self.derivative[s.name()]

        # Case 3: s is an already indexed symbol, e.g. MX(x[1])
        elif s.is_op(ca.OP_GETNONZEROS) and s.dep().is_symbolic():
            slice_info = s.info()['slice']
            dep = s.dep()
            if dep.name() not in self.derivative:
                der_dep = ca.MX.sym("der({})".format(dep.name()), dep.size())
                self.derivative[dep.name()] = der_dep
                # Register the new derivative symbol with the current class,
                # exactly as Case 2 does, so it is known to the model.
                self.nodes[self.current_class][der_dep.name()] = der_dep
                return der_dep[
                    slice_info['start']:slice_info['stop']:slice_info['step']]
            else:
                return self.derivative[dep.name(
                )][slice_info['start']:slice_info['stop']:slice_info['step']]

        # Case 4: s is an expression that requires differentiation, e.g. MX(x2 * x2)
        # Need to do this sort of expansion: der(x1 * x2) = der(x1) * x2 + x1 * der(x2)
        else:
            # Differentiate expression using CasADi
            orig_deps = ca.symvar(s)
            deps = ca.vertcat(*orig_deps)
            J = ca.Function('J', [deps], [ca.jacobian(s, deps)])
            J_sparsity = J.sparsity_out(0)

            # A dep contributes iff any row of J has a structural nonzero in
            # one of that dep's columns. Checking only (row 0, column j)
            # misses vector-valued s. It also mismatches column j with dep j
            # whenever an earlier dep spans several columns.
            nz_cols = set(J_sparsity.get_col())
            der_deps = []
            offset = 0
            for dep in orig_deps:
                # deps = vertcat(orig_deps), so this dep spans columns
                # [offset, offset + dep.numel()) of J. This assumes each dep
                # is a column vector -- TODO confirm for matrix symbols.
                n = dep.numel()
                if any(c in nz_cols for c in range(offset, offset + n)):
                    der_deps.append(self.get_derivative(dep))
                else:
                    der_deps.append(ca.DM.zeros(dep.size()))
                offset += n
            return ca.mtimes(J(deps), ca.vertcat(*der_deps))
예제 #2
0
    def test_extends_modification(self):
        """Flattening ``MainModel`` sets ``e.HQ.H``'s ``min`` to reference ``e.H_b``."""
        model_path = os.path.join(MODEL_DIR, 'ExtendsModification.mo')
        with open(model_path, 'r') as model_file:
            source = model_file.read()
        parsed = parser.parse(source)
        flattened = tree.flatten(parsed, ast.ComponentRef(name='MainModel'))

        main_symbols = flattened.classes['MainModel'].symbols
        self.assertEqual(main_symbols['e.HQ.H'].min.name, "e.H_b")
예제 #3
0
    def test_inheritance_symbol_modifiers(self):
        """Flattening ``Sub`` from Inheritance.mo yields ``x.max == 30.0``."""
        mo_path = os.path.join(MODEL_DIR, 'Inheritance.mo')
        with open(mo_path, 'r') as handle:
            source_text = handle.read()
        flat_tree = tree.flatten(parser.parse(source_text),
                                 ast.ComponentRef(name='Sub'))

        x_symbol = flat_tree.classes['Sub'].symbols['x']
        self.assertEqual(x_symbol.max.value, 30.0)
예제 #4
0
 def test_duplicate_state(self):
     """Smoke test: parse and flatten DuplicateState.mo, printing both trees."""
     source_path = os.path.join(MODEL_DIR, 'DuplicateState.mo')
     with open(source_path, 'r') as mo_file:
         model_text = mo_file.read()
     parsed = parser.parse(model_text)
     print('AST TREE\n', parsed)
     flattened = tree.flatten(parsed, ast.ComponentRef(name='DuplicateState'))
     print('AST TREE FLAT\n', flattened)
     self.flush()
예제 #5
0
 def test_spring_system(self):
     """Smoke test: parse and flatten SpringSystem.mo, printing both trees."""
     source_path = os.path.join(MODEL_DIR, 'SpringSystem.mo')
     with open(source_path, 'r') as mo_file:
         model_text = mo_file.read()
     parsed = parser.parse(model_text)
     print('AST TREE\n', parsed)
     flattened = tree.flatten(parsed, ast.ComponentRef(name='SpringSystem'))
     print('AST TREE FLAT\n', flattened)
     self.flush()
예제 #6
0
 def test_estimator(self):
     """Smoke test: parse and flatten Estimator.mo, printing both trees."""
     source_path = os.path.join(MODEL_DIR, './Estimator.mo')
     with open(source_path, 'r') as mo_file:
         model_text = mo_file.read()
     parsed = parser.parse(model_text)
     print('AST TREE\n', parsed)
     flattened = tree.flatten(parsed, ast.ComponentRef(name='Estimator'))
     print('AST TREE FLAT\n', flattened)
     self.flush()
예제 #7
0
    def test_nested_classes(self):
        """Flattening ``C2`` gives both ``v1`` and ``v2`` a nominal of 1000.0."""
        with open(os.path.join(MODEL_DIR, 'NestedClasses.mo'), 'r') as mo_file:
            source = mo_file.read()
        parsed = parser.parse(source)
        flattened = tree.flatten(parsed, ast.ComponentRef(name='C2'))

        c2_symbols = flattened.classes['C2'].symbols
        for symbol_name in ('v1', 'v2'):
            self.assertEqual(c2_symbols[symbol_name].nominal.value, 1000.0)
예제 #8
0
 def test_aircraft(self):
     """Smoke test: parse and flatten Aircraft.mo (from TEST_DIR), printing both trees."""
     source_path = os.path.join(TEST_DIR, 'Aircraft.mo')
     with open(source_path, 'r') as mo_file:
         model_text = mo_file.read()
     parsed = parser.parse(model_text)
     print('AST TREE\n', parsed)
     flattened = tree.flatten(parsed, ast.ComponentRef(name='Aircraft'))
     print('AST TREE FLAT\n', flattened)
     self.flush()
예제 #9
0
 def test_bouncing_ball(self):
     """Smoke test: parse and flatten BouncingBall.mo, printing both trees."""
     with open(os.path.join(MODEL_DIR, 'BouncingBall.mo'), 'r') as f:
         txt = f.read()
     ast_tree = parser.parse(txt)
     print('AST TREE\n', ast_tree)
     flat_tree = tree.flatten(ast_tree, ast.ComponentRef(name='BouncingBall'))
     # Print the flat tree once, with its label (previously it was also
     # dumped a second time without a label).
     print('AST TREE FLAT\n', flat_tree)
     self.flush()
예제 #10
0
    def test_inheritance(self):
        """Flattening ``C2`` assigns the expected values to inherited components."""
        source_path = os.path.join(MODEL_DIR, 'InheritanceInstantiation.mo')
        with open(source_path, 'r') as mo_file:
            source = mo_file.read()
        flattened = tree.flatten(parser.parse(source),
                                 ast.ComponentRef(name='C2'))

        # (symbol name, expected value) pairs, checked in order.
        expectations = (
            ('bcomp1.b', 3.0),
            ('bcomp3.a', 1.0),
            ('bcomp3.b', 2.0),
        )
        for symbol_name, expected in expectations:
            self.assertEqual(
                flattened.classes['C2'].symbols[symbol_name].value.value,
                expected)
예제 #11
0
 def get_derivative(self, s):
     """Return the time derivative ``der(s)`` of a CasADi expression.

     1. If ``s`` is a constant, return ``0``.
     2. If ``s`` is a symbol, return its cached ``der(<name>)`` symbol,
        creating it on first use (stored in ``self.derivative`` and
        ``self.nodes[self.current_class]``). Inside a for loop, symbols in
        the innermost loop's ``indexed_symbols`` instead map to an indexed
        element of the derivative of the un-indexed symbol (or ``0`` if that
        derivative is not symbolic).
     3. Otherwise apply the chain rule ``der(s) = J(s, deps) * der(deps)``
        over the free symbols ``deps`` of ``s``.

     :param s: CasADi expression (MX) or constant to differentiate.
     :returns: ``0`` or a CasADi expression for ``der(s)``.
     """
     if ca.MX(s).is_constant():
         return 0
     elif s.is_symbolic():
         if s.name() not in self.derivative:
             if len(self.for_loops
                    ) > 0 and s in self.for_loops[-1].indexed_symbols:
                 # Create a new indexed symbol, referencing to the for loop index inside the vector derivative symbol.
                 for_loop_symbol = self.for_loops[-1].indexed_symbols[s]
                 s_without_index = self.get_mx(
                     ast.ComponentRef(name=for_loop_symbol.tree.name))
                 der_s_without_index = self.get_derivative(s_without_index)
                 if ca.MX(der_s_without_index).is_symbolic():
                     return self.get_indexed_symbol(
                         ast.ComponentRef(
                             name=der_s_without_index.name(),
                             indices=for_loop_symbol.tree.indices),
                         der_s_without_index)
                 else:
                     return 0
             else:
                 der_s = ca.MX.sym("der({})".format(s.name()), s.size())
                 self.derivative[s.name()] = der_s
                 self.nodes[self.current_class][der_s.name()] = der_s
                 return der_s
         else:
             return self.derivative[s.name()]
     else:
         # Differentiate expression using CasADi
         orig_deps = ca.symvar(s)
         deps = ca.vertcat(*orig_deps)
         J = ca.Function('J', [deps], [ca.jacobian(s, deps)])
         J_sparsity = J.sparsity_out(0)

         # A dep contributes iff any row of J has a structural nonzero in
         # one of that dep's columns. Testing only (row 0, column j) is
         # wrong for vector-valued s. It is also wrong for vector deps whose
         # later elements are used: for der(x[1]) with vector x, the nonzero
         # is in column 1, not column 0, so the result would be zero.
         nz_cols = set(J_sparsity.get_col())
         der_deps = []
         offset = 0
         for dep in orig_deps:
             # deps = vertcat(orig_deps), so this dep spans columns
             # [offset, offset + dep.numel()) of J. This assumes each dep is
             # a column vector -- TODO confirm for matrix symbols.
             n = dep.numel()
             if any(c in nz_cols for c in range(offset, offset + n)):
                 der_deps.append(self.get_derivative(dep))
             else:
                 der_deps.append(ca.DM.zeros(dep.size()))
             offset += n
         return ca.mtimes(J(deps), ca.vertcat(*der_deps))
예제 #12
0
 def test_spring(self):
     """Generate SymPy code for SpringSystem, import it, and exercise it.

     Parses and flattens ``SpringSystem.mo`` from ``TEST_DIR`` (printing the
     flat tree). Then writes the ``gen_sympy`` output to
     ``generated/Spring.py`` under ``TEST_DIR``, imports the generated
     ``SpringSystem`` class, and runs ``linearize_symbolic``, ``linearize``
     and ``simulate(x0=[1.0, 1.0])`` on it.
     """
     src_path = os.path.join(TEST_DIR, 'SpringSystem.mo')
     with open(src_path, 'r') as src:
         model_text = src.read()
     parsed = parser.parse(model_text)
     flattened = tree.flatten(parsed, ast.ComponentRef(name='SpringSystem'))
     print(flattened)

     # gen_sympy.generate is given the parsed (un-flattened) tree plus the
     # class name; the flattened tree above is only printed.
     generated = gen_sympy.generate(parsed, 'SpringSystem')
     out_path = os.path.join(TEST_DIR, 'generated/Spring.py')
     with open(out_path, 'w') as out:
         out.write(generated)

     # The import must come after the module has been written to disk.
     from test.generated.Spring import SpringSystem
     system = SpringSystem()
     system.linearize_symbolic()
     system.linearize()
     system.simulate(x0=[1.0, 1.0])
     self.flush()
예제 #13
0
    def test_connector(self):
        """Smoke test: parse Connector.mo and flatten its ``Aircraft`` class.

        Also checks that a ``tree.TreeWalker`` can be constructed and that
        the parsed (un-flattened) tree exposes an ``Aircraft`` class.
        """
        with open(os.path.join(TEST_DIR, 'Connector.mo'), 'r') as f:
            txt = f.read()
        ast_tree = parser.parse(txt)

        # Flattening must not raise; the result itself is not inspected.
        tree.flatten(ast_tree, ast.ComponentRef(name='Aircraft'))

        # Previously these were bound to unused locals (walker, root) that
        # only served as smoke checks; keep the checks, drop the bindings.
        tree.TreeWalker()
        self.assertIn('Aircraft', ast_tree.classes)
        self.flush()