Exemplo n.º 1
0
    def test_promote_demote(self):
        # Variables can be promoted to states and demoted back to literals;
        # check the state-related queries after each step, then the errors.

        def assert_literal(variable):
            # Checks that ``variable`` is a plain literal, not a state
            self.assertTrue(variable.is_literal())
            self.assertTrue(variable.is_constant())
            self.assertFalse(variable.is_intermediary())
            self.assertFalse(variable.is_state())
            self.assertEqual(variable.lhs(), myokit.Name(variable))
            for method in (variable.demote, variable.indice,
                           variable.state_value):
                self.assertRaises(Exception, method)

        model = myokit.Model()
        comp = model.add_component('c')
        var = comp.add_variable('v')
        var.set_rhs(3)
        assert_literal(var)

        # Promote to a state with initial value 3
        var.promote(3)
        self.assertFalse(var.is_literal())
        self.assertFalse(var.is_constant())
        self.assertFalse(var.is_intermediary())
        self.assertTrue(var.is_state())
        self.assertEqual(var.lhs(), myokit.Derivative(myokit.Name(var)))
        self.assertEqual(var.indice(), 0)
        self.assertEqual(var.state_value(), 3)

        # Demoting turns it back into a literal
        var.demote()
        assert_literal(var)

        # Errors: promoting twice, promoting a bound variable, promoting a
        # nested variable
        var.promote(3)
        self.assertRaisesRegex(Exception, 'already', var.promote, 4)
        var.demote()
        var.set_binding('time')
        self.assertRaisesRegex(Exception, 'cannot be bound', var.promote, 4)
        nested = var.add_variable('w')
        self.assertRaisesRegex(
            Exception, 'only be added to Components', nested.promote, 4)

        # A state whose derivative is referenced elsewhere can't be demoted
        model = myokit.Model()
        comp = model.add_component('c')
        x = comp.add_variable('x')
        x.set_rhs(3)
        x.promote()
        y = comp.add_variable('y')
        y.set_rhs('1 + dot(x)')
        self.assertRaisesRegex(
            Exception, 'references to its derivative', x.demote)

        # Once the reference is removed, demotion succeeds
        y.set_rhs('1 + x')
        x.demote()
Exemplo n.º 2
0
 def _ex_name(self, e):
     """
     Converts the symbol ``e`` to a :class:`myokit.Name`, or to a
     :class:`myokit.Derivative` when its string form is ``dot(...)``.

     Without a model (``self._model`` falsy) the string form of ``e`` is
     wrapped in a :class:`myokit.Name` directly.
     """
     name = str(e)
     if not self._model:
         return myokit.Name(name)

     # Derivatives are encoded as "dot(<name>)".
     # See :meth:`SymPyExpressionWriter._ex_derivative()`.
     if name.startswith('dot(') and name.endswith(')'):
         state = self._model.get(name[4:-1], myokit.Variable)
         return myokit.Derivative(myokit.Name(state))
     return myokit.Name(self._model.get(name, myokit.Variable))
Exemplo n.º 3
0
    def _maths(self, parent, component):
        """
        Adds a ``math`` element to the given ``parent`` containing the maths
        for the variables in ``component``.

        Nothing is added when no variable in ``component`` has an RHS.
        """
        # Components without any equations get no <math> element
        if not any(v.rhs() is not None for v in component):
            return

        # Local alias of the free variable, if any. In valid models, this is
        # always set if states are present in this component.
        free = next(
            (v for v in component if v.value_source().is_free()), None)

        # Expression writer configured for this component
        from myokit.formats.cellml import CellMLExpressionWriter
        version = component.model().version()
        ewriter = CellMLExpressionWriter(version)
        ewriter.set_lhs_function(lambda x: x.var().name())
        ewriter.set_unit_function(lambda x: component.find_units_name(x))
        if free is not None:
            ewriter.set_time_variable(free)

        # MathML becomes the default namespace; the cellml prefix is mapped
        # to the namespace matching the model version.
        cellml_ns = (
            cellml.NS_CELLML_1_0 if version == '1.0' else cellml.NS_CELLML_1_1)
        nsmap = {None: cellml.NS_MATHML, 'cellml': cellml_ns}
        math = etree.SubElement(parent, 'math', nsmap=nsmap)

        # One equation per variable that has an RHS, in sorted order
        for variable in sorted(component, key=_name):
            rhs = variable.rhs()
            if rhs is None:
                continue
            lhs = myokit.Name(variable)
            if variable.is_state():
                lhs = myokit.Derivative(lhs)
            ewriter.eq(myokit.Equation(lhs, rhs), math)
Exemplo n.º 4
0
    def _test_maths(self, version):
        # Maths written in the given CellML ``version`` must survive a
        # write/parse round trip.

        # Component c: literal p, intermediary q, state r, free variable t
        model = cellml.Model('m', version)
        comp_c = model.add_component('c')
        p = comp_c.add_variable('p', 'mole')
        p.set_initial_value(2)
        q = comp_c.add_variable('q', 'dimensionless')
        r = comp_c.add_variable('r', 'second')
        r.set_initial_value(0.1)
        t = comp_c.add_variable('t', 'second')
        model.set_variable_of_integration(t)

        # Component d has no maths at all
        comp_d = model.add_component('d')
        s = comp_d.add_variable('s', 'volt')
        s.set_initial_value(1.23)

        # Equations for q and dot(r).
        # Note: Numbers without units become dimensionless in CellML
        eq_q = myokit.Equation(
            myokit.Name(q),
            myokit.Plus(myokit.Number(3, myokit.units.mole), myokit.Name(p)))
        eq_r = myokit.Equation(
            myokit.Derivative(myokit.Name(r)),
            myokit.Power(myokit.Name(q), myokit.Number(2)))
        q.set_equation(eq_q)
        r.set_equation(eq_r)

        # Round trip through XML
        parsed = cellml.parse_string(cellml.write_string(model))
        p2 = parsed['c']['p']
        q2 = parsed['c']['q']
        r2 = parsed['c']['r']
        s2 = parsed['d']['s']

        # Expected equations: originals with names mapped to the new model
        subst = {
            myokit.Name(old): myokit.Name(new)
            for old, new in ((p, p2), (q, q2), (r, r2))
        }
        expected_q = eq_q.clone(subst)
        expected_r = eq_r.clone(subst)

        self.assertEqual(q2.equation(), expected_q)
        self.assertEqual(r2.equation(), expected_r)
        self.assertEqual(
            s2.initial_value(), myokit.Number(1.23, myokit.units.volt))
        self.assertFalse(p2.is_state())
        self.assertFalse(q2.is_state())
        self.assertTrue(r2.is_state())
        self.assertFalse(s2.is_state())
        self.assertIs(parsed.variable_of_integration(), parsed['c']['t'])
Exemplo n.º 5
0
    def _parse_derivative(self, element, iterator):
        """
        Parses the elements following a ``<diff>`` element.

        Arguments

        ``element``
            A ``<diff>`` element
        ``iterator``
            An iterator pointing at the next element.

        Returns a :class:`myokit.Derivative` wrapping the result of
        ``self._parse_name`` for the ``<ci>`` that follows the ``<bvar>``. The
        name found inside the ``<bvar>`` (the variable of differentiation) is
        added to ``self._free_variables``.

        Raises a ``MathMLError`` if the ``<bvar>`` or either ``<ci>`` is
        missing, if a ``<degree>`` has no ``<cn>``, or if the degree is not
        one (only first-order derivatives are supported).
        """

        # Get free variable
        bvar = self._next(iterator, 'bvar')
        if bvar is None:
            raise MathMLError('<diff> element must contain a <bvar>.', element)
        ci = self._next(iter(bvar), 'ci')
        if ci is None:
            raise MathMLError('<bvar> element must contain a <ci>', element)
        self._free_variables.add(self._parse_name(ci))

        # Check degree, if given. A fresh iterator over the children of <bvar>
        # is used, so this search is independent of the <ci> search above.
        degree = self._next(iter(bvar), 'degree')
        if degree is not None:
            cn = self._next(iter(degree), 'cn')
            if cn is None:
                raise MathMLError('<degree> element must contain a <cn>.',
                                  degree)
            d = self._parse_number(cn)
            if d.eval() != 1:
                raise MathMLError(
                    'Only derivatives of degree one are supported.', cn)

        # Get Name object of the variable being differentiated
        ci = self._next(iterator, 'ci')
        if ci is None:
            raise MathMLError(
                '<diff> element must contain a <ci> after its <bvar>'
                ' element (derivatives of expressions are not supported).',
                element)
        var = self._parse_name(ci)

        return myokit.Derivative(var)
    def parsex(node):
        """
        Parses a mathml expression.

        Recursively converts the MathML DOM ``node`` (and, for operator
        elements, its following siblings) into a :class:`myokit.Expression`.

        This is a closure: it uses ``var_table``, ``logger``,
        ``derivative_post_processor``, ``number_post_processor``,
        ``dom_child``, ``dom_next``, ``parse_mathml_number`` and
        ``MathMLError`` from the enclosing scope.

        Raises a ``MathMLError`` for several malformed or unsupported
        constructs. Unknown elements are logged (if a ``logger`` is set) and
        returned as :class:`myokit.UnsupportedFunction` objects.
        """
        def chain(kind, node, unary=None):
            """
            Parses operands for chained operators (for example plus, minus,
            times and division).

            The argument ``kind`` must be the myokit expression type being
            parsed, ``node`` is a DOM node and ``unary``, if given, should be
            the unary expression type (unary Plus or unary Minus).

            If ``kind`` is ``None`` the operator is strictly unary (e.g.
            ``not``) and more than one operand raises a ``MathMLError``.

            Operands are combined left to right: ``a, b, c`` becomes
            ``kind(kind(a, b), c)``.
            """
            ops = []
            node = dom_next(node)
            while node:
                ops.append(parsex(node))
                node = dom_next(node)
            n = len(ops)
            if n < 1:
                raise MathMLError('Operator needs at least one operand.')
            if n < 2:
                if unary:
                    return unary(ops[0])
                else:
                    raise MathMLError('Operator needs at least two operands')
            if kind is None:
                # Strictly unary operator (e.g. <not>) with several operands:
                # report it instead of calling None(...)
                raise MathMLError('Operator needs exactly one operand.')
            ex = kind(ops[0], ops[1])
            for i in range(2, n):
                ex = kind(ex, ops[i])
            return ex

        # Start parsing
        name = node.tagName
        if name == 'apply':
            # Brackets, can be ignored in an expression tree.
            return parsex(dom_child(node))

        elif name == 'ci':
            # Reference
            var = str(node.firstChild.data).strip()
            if var_table is not None:
                try:
                    var = var_table[var]
                except KeyError:
                    if logger:
                        logger.warn('Unable to resolve reference to <' +
                                    str(var) + '>.')
            return myokit.Name(var)

        elif name == 'diff':
            # Derivative
            # Check time variable
            bvar = dom_next(node, 'bvar')
            if bvar is None:
                raise MathMLError('<diff> element must contain a <bvar>.')
            if derivative_post_processor:
                ci = dom_child(bvar, 'ci')
                if ci is None:
                    raise MathMLError('<bvar> element must contain a <ci>.')
                derivative_post_processor(parsex(ci))

            # Check degree, if given
            d = dom_child(bvar, 'degree')
            if d is not None:
                cn = dom_child(d, 'cn')
                if cn is None:
                    raise MathMLError('<degree> element must contain a <cn>.')
                d = parsex(cn).eval()
                if not d == 1:
                    raise MathMLError(
                        'Only derivatives of degree one are supported.')

            # Create derivative and return
            x = dom_next(node, 'ci')
            if x is None:
                raise MathMLError(
                    'Derivative of an expression found: only derivatives of'
                    ' variables are supported.')
            return myokit.Derivative(parsex(x))

        elif name == 'cn':
            # Number
            number = parse_mathml_number(node, logger)
            if number_post_processor:
                return number_post_processor(node, number)
            return number

        #
        # Algebra
        #

        elif name == 'plus':
            return chain(myokit.Plus, node, myokit.PrefixPlus)

        elif name == 'minus':
            return chain(myokit.Minus, node, myokit.PrefixMinus)

        elif name == 'times':
            return chain(myokit.Multiply, node)

        elif name == 'divide':
            return chain(myokit.Divide, node)

        #
        # Functions
        #

        elif name == 'exp':
            return myokit.Exp(parsex(dom_next(node)))

        elif name == 'ln':
            return myokit.Log(parsex(dom_next(node)))

        elif name == 'log':
            # Base 10 unless a <logbase> is given
            if dom_next(node).tagName != 'logbase':
                return myokit.Log10(parsex(dom_next(node)))
            else:
                return myokit.Log(parsex(dom_next(dom_next(node))),
                                  parsex(dom_child(dom_next(node))))

        elif name == 'root':
            # Check degree, if given
            nxt = dom_next(node)
            if nxt.tagName == 'degree':
                # Degree given, return x^(1/d) unless d is 2
                d = parsex(dom_child(nxt))
                x = parsex(dom_next(nxt))
                if d.is_literal() and d.eval() == 2:
                    return myokit.Sqrt(x)
                return myokit.Power(x, myokit.Divide(myokit.Number(1), d))
            else:
                return myokit.Sqrt(parsex(nxt))

        elif name == 'power':
            n2 = dom_next(node)
            return myokit.Power(parsex(n2), parsex(dom_next(n2)))

        elif name == 'floor':
            return myokit.Floor(parsex(dom_next(node)))

        elif name == 'ceiling':
            return myokit.Ceil(parsex(dom_next(node)))

        elif name == 'abs':
            return myokit.Abs(parsex(dom_next(node)))

        elif name == 'quotient':
            n2 = dom_next(node)
            return myokit.Quotient(parsex(n2), parsex(dom_next(n2)))

        elif name == 'rem':
            n2 = dom_next(node)
            return myokit.Remainder(parsex(n2), parsex(dom_next(n2)))

        #
        # Trigonometry
        #

        elif name == 'sin':
            return myokit.Sin(parsex(dom_next(node)))

        elif name == 'cos':
            return myokit.Cos(parsex(dom_next(node)))

        elif name == 'tan':
            return myokit.Tan(parsex(dom_next(node)))

        elif name == 'arcsin':
            return myokit.ASin(parsex(dom_next(node)))

        elif name == 'arccos':
            return myokit.ACos(parsex(dom_next(node)))

        elif name == 'arctan':
            return myokit.ATan(parsex(dom_next(node)))

        #
        # Redundant trigonometry (CellML includes this)
        #

        elif name == 'csc':
            # Cosecant: csc(x) = 1 / sin(x)
            return myokit.Divide(myokit.Number(1),
                                 myokit.Sin(parsex(dom_next(node))))

        elif name == 'sec':
            # Secant: sec(x) = 1 / cos(x)
            return myokit.Divide(myokit.Number(1),
                                 myokit.Cos(parsex(dom_next(node))))

        elif name == 'cot':
            # Cotangent: cot(x) = 1 / tan(x)
            return myokit.Divide(myokit.Number(1),
                                 myokit.Tan(parsex(dom_next(node))))

        elif name == 'arccsc':
            # ArcCosecant: acsc(x) = asin(1/x)
            return myokit.ASin(
                myokit.Divide(myokit.Number(1), parsex(dom_next(node))))

        elif name == 'arcsec':
            # ArcSecant: asec(x) = acos(1/x)
            return myokit.ACos(
                myokit.Divide(myokit.Number(1), parsex(dom_next(node))))

        elif name == 'arccot':
            # ArcCotangent: acot(x) = atan(1/x)
            return myokit.ATan(
                myokit.Divide(myokit.Number(1), parsex(dom_next(node))))

        #
        # Hyperbolic trigonometry (CellML again)
        #

        elif name == 'sinh':
            # Hyperbolic sine: sinh(x) = 0.5 * (e^x - e^-x)
            x = parsex(dom_next(node))
            return myokit.Multiply(
                myokit.Number(0.5),
                myokit.Minus(myokit.Exp(x), myokit.Exp(myokit.PrefixMinus(x))))

        elif name == 'cosh':
            # Hyperbolic cosine: cosh(x) = 0.5 * (e^x + e^-x)
            x = parsex(dom_next(node))
            return myokit.Multiply(
                myokit.Number(0.5),
                myokit.Plus(myokit.Exp(x), myokit.Exp(myokit.PrefixMinus(x))))

        elif name == 'tanh':
            # Hyperbolic tangent: tanh(x) = (e^2x - 1) / (e^2x + 1)
            x = parsex(dom_next(node))
            e2x = myokit.Exp(myokit.Multiply(myokit.Number(2), x))
            return myokit.Divide(myokit.Minus(e2x, myokit.Number(1)),
                                 myokit.Plus(e2x, myokit.Number(1)))

        #
        # Inverse hyperbolic trigonometry (CellML...)
        #

        elif name == 'arcsinh':
            # Inverse hyperbolic sine: asinh(x) = log(x + sqrt(1 + x*x))
            x = parsex(dom_next(node))
            return myokit.Log(
                myokit.Plus(
                    x,
                    myokit.Sqrt(
                        myokit.Plus(myokit.Number(1), myokit.Multiply(x, x)))))

        elif name == 'arccosh':
            # Inverse hyperbolic cosine:
            #   acosh(x) = log(x + sqrt(x + 1) * sqrt(x - 1))
            x = parsex(dom_next(node))
            return myokit.Log(
                myokit.Plus(
                    x,
                    myokit.Multiply(
                        myokit.Sqrt(myokit.Plus(x, myokit.Number(1))),
                        myokit.Sqrt(myokit.Minus(x, myokit.Number(1))))))

        elif name == 'arctanh':
            # Inverse hyperbolic tangent:
            #   atanh(x) = 0.5 * (log(1 + x) - log(1 - x))
            x = parsex(dom_next(node))
            return myokit.Multiply(
                myokit.Number(0.5),
                myokit.Minus(myokit.Log(myokit.Plus(myokit.Number(1), x)),
                             myokit.Log(myokit.Minus(myokit.Number(1), x))))

        #
        # Hyperbolic redundant trigonometry (CellML...)
        #

        elif name == 'csch':
            # Hyperbolic cosecant: csch(x) = 2 / (exp(x) - exp(-x))
            x = parsex(dom_next(node))
            return myokit.Divide(
                myokit.Number(2),
                myokit.Minus(myokit.Exp(x), myokit.Exp(myokit.PrefixMinus(x))))

        elif name == 'sech':
            # Hyperbolic secant: sech(x) = 2 / (exp(x) + exp(-x))
            x = parsex(dom_next(node))
            return myokit.Divide(
                myokit.Number(2),
                myokit.Plus(myokit.Exp(x), myokit.Exp(myokit.PrefixMinus(x))))

        elif name == 'coth':
            # Hyperbolic cotangent:
            #   coth(x) = (exp(2*x) + 1) / (exp(2*x) - 1)
            x = parsex(dom_next(node))
            e2x = myokit.Exp(myokit.Multiply(myokit.Number(2), x))
            return myokit.Divide(myokit.Plus(e2x, myokit.Number(1)),
                                 myokit.Minus(e2x, myokit.Number(1)))

        #
        # Inverse hyperbolic redundant trigonometry (CellML has a lot to answer
        # for...)
        #

        elif name == 'arccsch':
            # Inverse hyperbolic cosecant:
            #   arccsch(x) = log(sqrt(1/(x*x) + 1) + 1/x)
            x = parsex(dom_next(node))
            return myokit.Log(
                myokit.Plus(
                    myokit.Sqrt(
                        myokit.Plus(
                            myokit.Divide(myokit.Number(1),
                                          myokit.Multiply(x, x)),
                            myokit.Number(1))),
                    myokit.Divide(myokit.Number(1), x)))
        elif name == 'arcsech':
            # Inverse hyperbolic secant:
            #   arcsech(x) = log(sqrt(1/(x*x) - 1) + 1/x)
            x = parsex(dom_next(node))
            return myokit.Log(
                myokit.Plus(
                    myokit.Sqrt(
                        myokit.Minus(
                            myokit.Divide(myokit.Number(1),
                                          myokit.Multiply(x, x)),
                            myokit.Number(1))),
                    myokit.Divide(myokit.Number(1), x)))
        elif name == 'arccoth':
            # Inverse hyperbolic cotangent:
            #   arccoth(x) = 0.5 * log((x + 1) / (x - 1))
            #              = 0.5 * (log(x + 1) - log(x - 1))
            x = parsex(dom_next(node))
            return myokit.Multiply(
                myokit.Number(0.5),
                myokit.Log(
                    myokit.Divide(myokit.Plus(x, myokit.Number(1)),
                                  myokit.Minus(x, myokit.Number(1)))))

        #
        # Logic
        #

        elif name == 'and':
            return chain(myokit.And, node)

        elif name == 'or':
            return chain(myokit.Or, node)

        elif name == 'not':
            # Strictly unary: chain() rejects more than one operand
            return chain(None, node, myokit.Not)

        elif name == 'eq' or name == 'equivalent':
            n2 = dom_next(node)
            return myokit.Equal(parsex(n2), parsex(dom_next(n2)))

        elif name == 'neq':
            n2 = dom_next(node)
            return myokit.NotEqual(parsex(n2), parsex(dom_next(n2)))

        elif name == 'gt':
            n2 = dom_next(node)
            return myokit.More(parsex(n2), parsex(dom_next(n2)))

        elif name == 'lt':
            n2 = dom_next(node)
            return myokit.Less(parsex(n2), parsex(dom_next(n2)))

        elif name == 'geq':
            n2 = dom_next(node)
            return myokit.MoreEqual(parsex(n2), parsex(dom_next(n2)))

        elif name == 'leq':
            n2 = dom_next(node)
            return myokit.LessEqual(parsex(n2), parsex(dom_next(n2)))

        elif name == 'piecewise':
            # Piecewise contains at least one piece, optionally contains an
            #  "otherwise". Syntax doesn't ensure this statement makes sense.
            conds = []
            funcs = []
            other = None
            piece = dom_child(node)
            while piece:
                if piece.tagName == 'otherwise':
                    if other is None:
                        other = parsex(dom_child(piece))
                    elif logger:
                        logger.warn(
                            'Multiple <otherwise> tags found in <piecewise>'
                            ' statement.')
                elif piece.tagName == 'piece':
                    # A <piece> holds the value first, then the condition
                    n2 = dom_child(piece)
                    funcs.append(parsex(n2))
                    conds.append(parsex(dom_next(n2)))
                elif logger:
                    logger.warn('Unexpected tag type in <piecewise>: <' +
                                piece.tagName + '>.')
                piece = dom_next(piece)

            if other is None:
                if logger:
                    logger.warn('No <otherwise> tag found in <piecewise>')
                other = myokit.Number(0)

            # Create string of if statements: (cond, value, cond, value, ...,
            # otherwise)
            args = []
            f = iter(funcs)
            for c in conds:
                args.append(c)
                args.append(next(f))
            args.append(other)
            return myokit.Piecewise(*args)

        #
        # Constants
        #

        elif name == 'pi':
            return myokit.Number('3.14159265358979323846')
        elif name == 'exponentiale':
            return myokit.Exp(myokit.Number(1))
        elif name == 'true':
            # True is represented as 1 (as in Python, where True == 1)
            return myokit.Number(1)
        elif name == 'false':
            return myokit.Number(0)

        #
        # Unknown/unhandled elements
        #
        else:
            if logger:
                logger.warn('Unknown element: ' + name)
            ops = []
            # Operands are the element's children if it has any, otherwise
            # its following siblings
            child = dom_child(node)
            node = child if child else dom_next(node)
            while node:
                ops.append(parsex(node))
                node = dom_next(node)
            return myokit.UnsupportedFunction(name, ops)
Exemplo n.º 7
0
    def test_sensitivities(self):
        # CModel instantiation with (and without) sensitivities

        def create(deps, indeps):
            # Creates a CModel with sensitivities ``(deps, indeps)``
            return myokit.CModel(self.model, (deps, indeps))

        # Bad type
        with self.assertRaisesRegex(ValueError, 'The argument `sensitivities'):
            myokit.CModel(self.model, 'Bad type')

        # Empty dependents or independents disable sensitivities
        self.assertFalse(create([], ['some parameter']).has_sensitivities)
        self.assertFalse(create(['some state'], []).has_sensitivities)

        s1 = self.model.get('ik1.gK1')
        s2 = self.model.get('ikp.IKp')
        p1 = self.model.get('cell.K_o')
        p2 = self.model.get('ikp.gKp')
        name_s2 = myokit.Name(s2)
        name_params = [myokit.Name(p1), myokit.Name(p2)]

        # Sensitivities given as Variables
        self.assertTrue(create([s1, s2], [p1, p2]).has_sensitivities)

        # Sensitivities given as Names
        self.assertTrue(
            create([myokit.Name(s1), name_s2], name_params).has_sensitivities)

        # Sensitivity of a derivative, as Derivative and as string
        x = self.model.get('ik.x')
        dot_x = myokit.Derivative(myokit.Name(x))
        self.assertTrue(create([dot_x, name_s2], name_params).has_sensitivities)
        self.assertTrue(
            create(['dot(ik.x)', name_s2], name_params).has_sensitivities)

        # Sensitivity of the derivative of a non-state
        dot_s1 = myokit.Derivative(myokit.Name(s1))
        with self.assertRaisesRegex(ValueError, 'Sensitivity of '):
            create([dot_s1, name_s2], name_params)

        # Sensitivity of a bound variable
        with self.assertRaisesRegex(ValueError, 'Sensitivities cannot'):
            create(['engine.time', name_s2], name_params)

        # Sensitivity with respect to an initial value
        x = self.model.get('ik.x')
        init_x = myokit.InitialValue(myokit.Name(x))
        self.assertTrue(
            create([x, name_s2], [myokit.Name(p1), init_x]).has_sensitivities)

        # Sensitivity with respect to the initial value of a non-state
        init_p2 = myokit.InitialValue(myokit.Name(p2))
        with self.assertRaisesRegex(ValueError, 'Sensitivity with respect to'):
            create([myokit.Name(s1), name_s2], [myokit.Name(p1), init_p2])

        # Sensitivity with respect to a non-literal
        with self.assertRaisesRegex(ValueError, 'Sensitivity with respect to'):
            create([myokit.Name(s1), name_s2], [myokit.Name(p1), 'ik.E'])
Exemplo n.º 8
0
    def _parse_sensitivities(self, model, sensitivities):
        """
        Parses the ``sensitivities`` constructor argument.

        Returns a tuple ``(has_sensitivities, dependents, independents)``,
        where ``has_sensitivities`` is a boolean and the other two entries
        are lists of :class:`myokit.Expression` objects that refer to
        variables in ``model``.

        Dependents (y in dy/dx) may be given as:

        - Variable (state or intermediary)
        - Name or Derivative
        - "ina.INa" or "dot(membrane.V)"

        Independents (x in dy/dx) may be given as:

        - Variable (literal)
        - Name or InitialValue
        - "ikr.gKr" or "init(membrane.V)"

        ``(False, [], [])`` is returned if ``sensitivities`` is ``None`` or
        if either list is empty. A ``ValueError`` is raised if the argument
        is not a pair of iterables, if a derivative or initial value of a
        non-state is requested, if a dependent is bound, or if an independent
        is not a literal.
        """
        if sensitivities is None:
            return False, [], []

        # Unpack into two lists
        try:
            deps, indeps = sensitivities
            deps, indeps = list(deps), list(indeps)
        except Exception:
            raise ValueError(
                'The argument `sensitivities` must be None, or a tuple'
                ' containing two lists.')
        if not deps or not indeps:
            return False, [], []

        def resolve(x, wrapper_type, prefix):
            # Finds the variable in ``model`` that ``x`` refers to. Returns
            # ``(variable, wrapped)``, where ``wrapped`` is True if ``x`` was
            # a ``wrapper_type`` object or a ``prefix...)`` string.
            if isinstance(x, myokit.Variable):
                return model.get(x.qname()), False
            if isinstance(x, myokit.Name):
                return model.get(x.var().qname()), False
            if isinstance(x, wrapper_type):
                return model.get(x.var().qname()), True
            text = str(x)
            wrapped = text.startswith(prefix) and text.endswith(')')
            if wrapped:
                text = text[len(prefix):-1]
            return model.get(text), wrapped

        # Dependents become Name or Derivative objects from ``model``.
        # Note: constants are fine, just not very useful! But may be easy,
        # e.g. when working with multiple models.
        dependents = []
        for x in deps:
            var, is_derivative = resolve(x, myokit.Derivative, 'dot(')
            lhs = myokit.Name(var)
            if is_derivative:
                lhs = myokit.Derivative(lhs)
                if not var.is_state():
                    raise ValueError('Sensitivity of ' + lhs.code() +
                                     ' requested, but ' + var.qname() +
                                     ' is not a state variable.')
            elif var.is_bound():
                raise ValueError('Sensitivities cannot be calculated for bound'
                                 ' variables (got ' + str(var.qname()) + ').')
            dependents.append(lhs)

        # Independents become Name or InitialValue objects from ``model``
        independents = []
        for x in indeps:
            var, is_initial = resolve(x, myokit.InitialValue, 'init(')
            if is_initial:
                lhs = myokit.InitialValue(myokit.Name(var))
                if not var.is_state():
                    raise ValueError('Sensitivity with respect to ' +
                                     lhs.code() + ' requested, but ' +
                                     var.qname() + ' is not a'
                                     ' state variable.')
            else:
                if not var.is_literal():
                    raise ValueError(
                        'Sensitivity with respect to ' + var.qname() +
                        ' requested, but this is not a literal variable (it'
                        ' depends on other variables).')
                lhs = myokit.Name(var)
            independents.append(lhs)

        return True, dependents, independents
Exemplo n.º 9
0
    def test_model_creation(self):
        # Build a small Lotka-Volterra model piece by piece, checking variable
        # classification (constant / intermediary / state), variable and
        # equation iteration, dependency maps, solvable order, and that
        # exporting a clone gives the same code as the original.

        # Create a model
        m = myokit.Model('LotkaVolterra')

        # Add the first component
        X = m.add_component('X')
        self.assertEqual(X.qname(), 'X')
        self.assertEqual(X.parent(), m)
        self.assertIsInstance(X, myokit.Component)
        self.assertIn(X.qname(), m)
        self.assertEqual(len(m), 1)

        # Add variable a
        self.assertFalse(X.has_variable('a'))
        a = X.add_variable('a')
        self.assertTrue(X.has_variable('a'))
        self.assertEqual(a, a)
        self.assertIsInstance(a, myokit.Variable)
        self.assertEqual(len(X), 1)
        self.assertIn(a.name(), X)
        a.set_rhs(3)
        self.assertFalse(a.is_state())
        self.assertFalse(a.is_intermediary())
        self.assertTrue(a.is_constant())
        self.assertEqual(a.lhs(), myokit.Name(a))
        self.assertEqual(a.rhs(), myokit.Number(3))
        self.assertEqual(a.rhs().eval(), 3)
        self.assertEqual(a.code(), 'a = 3\n')
        self.assertEqual(a.eq().code(), 'X.a = 3')
        self.assertEqual(a.lhs().code(), 'X.a')
        self.assertEqual(a.rhs().code(), '3')
        self.assertEqual(
            a.eq(), myokit.Equation(myokit.Name(a), myokit.Number(3)))

        # Check lhs: two separately created Name objects for the same
        # variable must compare equal, in both directions
        a_name1 = myokit.Name(a)
        a_name2 = myokit.Name(a)
        self.assertEqual(a_name1, a_name1)
        self.assertEqual(a_name2, a_name2)
        self.assertEqual(a_name1, a_name2)
        self.assertEqual(a_name2, a_name1)

        # Add variable b with two temporary variables
        b = X.add_variable('b')
        self.assertIsInstance(b, myokit.Variable)
        self.assertEqual(len(X), 2)
        self.assertIn(b.name(), X)
        self.assertFalse(b.has_variable('b1'))
        b1 = b.add_variable('b1')
        self.assertTrue(b.has_variable('b1'))
        self.assertEqual(len(b), 1)
        self.assertIn(b1.name(), b)
        self.assertIsInstance(b1, myokit.Variable)
        b2 = b.add_variable('b2')
        self.assertEqual(len(b), 2)
        self.assertIn(b2.name(), b)
        self.assertIsInstance(b2, myokit.Variable)
        b1.set_rhs(1)
        b2.set_rhs(
            myokit.Minus(
                myokit.Minus(myokit.Name(a), myokit.Name(b1)),
                myokit.Number(1))
        )
        b.set_rhs(myokit.Plus(myokit.Name(b1), myokit.Name(b2)))
        # b = b1 + b2 = 1 + (3 - 1 - 1) = 2
        self.assertEqual(b.rhs().eval(), 2)
        self.assertFalse(b.is_state())
        self.assertFalse(b.is_intermediary())
        self.assertTrue(b.is_constant())
        self.assertEqual(b.lhs(), myokit.Name(b))

        # Add state variable x
        x = X.add_variable('x')
        x.set_rhs(10)
        x.promote()
        self.assertNotEqual(x, X)
        self.assertIsInstance(x, myokit.Variable)
        self.assertEqual(len(X), 3)
        self.assertIn(x.name(), X)
        self.assertTrue(x.is_state())
        self.assertFalse(x.is_intermediary())
        self.assertFalse(x.is_constant())
        # A state's lhs is its derivative, dot(x)
        self.assertEqual(x.lhs(), myokit.Derivative(myokit.Name(x)))
        self.assertEqual(x.indice(), 0)

        # Test demoting, promoting (repeatedly, to check it round-trips)
        x.demote()
        self.assertFalse(x.is_state())
        self.assertFalse(x.is_intermediary())
        self.assertTrue(x.is_constant())
        self.assertEqual(x.lhs(), myokit.Name(x))
        x.promote()
        self.assertTrue(x.is_state())
        self.assertFalse(x.is_intermediary())
        self.assertFalse(x.is_constant())
        self.assertEqual(x.lhs(), myokit.Derivative(myokit.Name(x)))
        x.demote()
        x.promote()
        x.demote()
        x.promote()
        self.assertTrue(x.is_state())
        self.assertFalse(x.is_intermediary())
        self.assertFalse(x.is_constant())
        self.assertEqual(x.lhs(), myokit.Derivative(myokit.Name(x)))

        # Add second component, variables
        Y = m.add_component('Y')
        self.assertNotEqual(X, Y)
        self.assertEqual(len(m), 2)
        c = Y.add_variable('c')
        c.set_rhs(myokit.Minus(myokit.Name(a), myokit.Number(1)))
        d = Y.add_variable('d')
        d.set_rhs(2)
        y = Y.add_variable('y')
        y.promote()

        # Set rhs for x and y (the Lotka-Volterra equations)
        x.set_rhs(myokit.Minus(
            myokit.Multiply(myokit.Name(a), myokit.Name(x)),
            myokit.Multiply(
                myokit.Multiply(myokit.Name(b), myokit.Name(x)),
                myokit.Name(y)
            )
        ))
        x.set_state_value(10)
        self.assertEqual(x.rhs().code(), 'X.a * X.x - X.b * X.x * Y.y')
        y.set_rhs(myokit.Plus(
            myokit.Multiply(
                myokit.PrefixMinus(myokit.Name(c)), myokit.Name(y)
            ),
            myokit.Multiply(
                myokit.Multiply(myokit.Name(d), myokit.Name(x)),
                myokit.Name(y)
            )
        ))
        y.set_state_value(5)
        self.assertEqual(y.rhs().code(), '-Y.c * Y.y + Y.d * X.x * Y.y')

        # Add another component, variables
        Z = m.add_component('Z')
        self.assertNotEqual(X, Z)
        self.assertNotEqual(Y, Z)
        self.assertEqual(len(m), 3)
        t = Z.add_variable('total')
        self.assertEqual(t.name(), 'total')
        self.assertEqual(t.qname(), 'Z.total')
        # qname(component) omits the prefix only for that component itself
        self.assertEqual(t.qname(X), 'Z.total')
        self.assertEqual(t.qname(Z), 'total')
        t.set_rhs(myokit.Plus(myokit.Name(x), myokit.Name(y)))
        # Depends on states, so it is an intermediary, not a constant
        self.assertFalse(t.is_state())
        self.assertFalse(t.is_constant())
        self.assertTrue(t.is_intermediary())
        self.assertEqual(t.rhs().code(), 'X.x + Y.y')
        self.assertEqual(t.rhs().code(X), 'x + Y.y')
        self.assertEqual(t.rhs().code(Y), 'X.x + y')
        self.assertEqual(t.rhs().code(Z), 'X.x + Y.y')

        # Add engine component
        E = m.add_component('engine')
        self.assertNotEqual(X, E)
        self.assertNotEqual(Y, E)
        self.assertNotEqual(Z, E)
        self.assertEqual(len(m), 4)
        time = E.add_variable('time')
        time.set_rhs(0)
        self.assertIsNone(time.binding())
        time.set_binding('time')
        self.assertIsNotNone(time.binding())

        # Check state
        state = [i for i in m.states()]
        self.assertEqual(len(state), 2)
        self.assertIn(x, state)
        self.assertIn(y, state)

        # Test variable iterators
        # Helper: assert that ``vrs`` (rebound before each call) contains
        # each of the given variables and has the same length as ``v``.
        def has(*v):
            for var in v:
                self.assertIn(var, vrs)
            self.assertEqual(len(vrs), len(v))
        vrs = [i for i in m.variables()]
        has(a, b, c, d, x, y, t, time)
        vrs = [i for i in m.variables(deep=True)]
        has(a, b, c, d, x, y, t, b1, b2, time)
        vrs = [i for i in m.variables(const=True)]
        has(a, b, c, d)
        vrs = [i for i in m.variables(const=True, deep=True)]
        has(a, b, c, d, b1, b2)
        vrs = [i for i in m.variables(const=False)]
        has(x, y, t, time)
        vrs = [i for i in m.variables(const=False, deep=True)]
        has(x, y, t, time)
        vrs = [i for i in m.variables(state=True)]
        has(x, y)
        vrs = [i for i in m.variables(state=True, deep=True)]
        has(x, y)
        vrs = [i for i in m.variables(state=False)]
        has(a, b, c, d, t, time)
        vrs = [i for i in m.variables(state=False, deep=True)]
        has(a, b, c, d, t, b1, b2, time)
        vrs = [i for i in m.variables(inter=True)]
        has(t)
        vrs = [i for i in m.variables(inter=True, deep=True)]
        has(t)
        vrs = [i for i in m.variables(inter=False)]
        has(a, b, c, d, x, y, time)
        vrs = [i for i in m.variables(inter=False, deep=True)]
        has(a, b, c, d, x, y, b1, b2, time)
        # Contradictory filters select nothing
        vrs = list(m.variables(const=True, state=True))
        has()
        vrs = list(m.variables(const=True, state=False))
        has(a, b, c, d)

        # Test sorted variable iteration
        names = [v.name() for v in m.variables(deep=True, sort=True)]
        self.assertEqual(names, [
            'a', 'b', 'b1', 'b2', 'x', 'c', 'd', 'y', 'total', 'time'])

        # Test equation iteration
        # Deeper testing is done when testing the ``variables`` method.
        eq = [eq for eq in X.equations(deep=False)]
        self.assertEqual(len(eq), 3)
        self.assertEqual(len(eq), X.count_equations(deep=False))
        eq = [eq for eq in X.equations(deep=True)]
        self.assertEqual(len(eq), 5)
        self.assertEqual(len(eq), X.count_equations(deep=True))
        eq = [eq for eq in Y.equations(deep=False)]
        self.assertEqual(len(eq), 3)
        self.assertEqual(len(eq), Y.count_equations(deep=False))
        eq = [eq for eq in Y.equations(deep=True)]
        self.assertEqual(len(eq), 3)
        self.assertEqual(len(eq), Y.count_equations(deep=True))
        eq = [eq for eq in Z.equations(deep=False)]
        self.assertEqual(len(eq), 1)
        self.assertEqual(len(eq), Z.count_equations(deep=False))
        eq = [eq for eq in Z.equations(deep=True)]
        self.assertEqual(len(eq), 1)
        self.assertEqual(len(eq), Z.count_equations(deep=True))
        eq = [eq for eq in E.equations(deep=False)]
        self.assertEqual(len(eq), 1)
        eq = [eq for eq in E.equations(deep=True)]
        self.assertEqual(len(eq), 1)
        eq = [eq for eq in m.equations(deep=False)]
        self.assertEqual(len(eq), 8)
        eq = [eq for eq in m.equations(deep=True)]
        self.assertEqual(len(eq), 10)

        # Test dependency mapping
        # Helper: look up the dependency list of ``var`` in the mapping
        # ``vrs`` and check it has exactly the entries in ``dps``. String
        # arguments are resolved to the named variable's lhs.
        def has(var, *dps):
            lst = vrs[m.get(var).lhs() if isinstance(var, basestring) else var]
            self.assertEqual(len(lst), len(dps))
            for d in dps:
                d = m.get(d).lhs() if isinstance(d, basestring) else d
                self.assertIn(d, lst)

        # With omit_states=False the state values Name(x) and Name(y) get
        # entries of their own, giving 12 keys instead of 10
        vrs = m.map_shallow_dependencies(omit_states=False)
        self.assertEqual(len(vrs), 12)
        has('X.a')
        has('X.b', 'X.b.b1', 'X.b.b2')
        has('X.b.b1')
        has('X.b.b2', 'X.a', 'X.b.b1')
        has('X.x', 'X.a', 'X.b', myokit.Name(x), myokit.Name(y))
        has(myokit.Name(x))
        has('Y.c', 'X.a')
        has('Y.d')
        has('Y.y', 'Y.c', 'Y.d', myokit.Name(x), myokit.Name(y))
        has(myokit.Name(y))
        has('Z.total', myokit.Name(x), myokit.Name(y))
        vrs = m.map_shallow_dependencies()
        self.assertEqual(len(vrs), 10)
        has('X.a')
        has('X.b', 'X.b.b1', 'X.b.b2')
        has('X.b.b1')
        has('X.b.b2', 'X.a', 'X.b.b1')
        has('X.x', 'X.a', 'X.b')
        has('Y.c', 'X.a')
        has('Y.d')
        has('Y.y', 'Y.c', 'Y.d')
        has('Z.total')
        # With collapse=True the nested b1, b2 disappear and X.b depends on
        # X.a directly (8 keys)
        vrs = m.map_shallow_dependencies(collapse=True)
        self.assertEqual(len(vrs), 8)
        has('X.a')
        has('X.b', 'X.a')
        has('X.x', 'X.a', 'X.b')
        has('Y.c', 'X.a')
        has('Y.d')
        has('Y.y', 'Y.c', 'Y.d')
        has('Z.total')

        # Validate
        m.validate()

        # Get solvable order
        # Five groups: presumably one per component (X, Y, Z, engine) plus
        # '*remaining*'
        order = m.solvable_order()
        self.assertEqual(len(order), 5)
        self.assertIn('*remaining*', order)
        self.assertIn('X', order)
        self.assertIn('Y', order)
        self.assertIn('Z', order)

        # Check that X comes before Y (Y.c depends on X.a), and that
        # '*remaining*' is listed last
        pos = dict([(name, k) for k, name in enumerate(order)])
        self.assertLess(pos['X'], pos['Y'])
        self.assertEqual(pos['*remaining*'], 4)

        # Check component equation lists
        eqs = order['*remaining*']
        self.assertEqual(len(eqs), 0)
        eqs = order['Z']
        self.assertEqual(len(eqs), 1)
        self.assertEqual(eqs[0].code(), 'Z.total = X.x + Y.y')
        eqs = order['Y']
        self.assertEqual(len(eqs), 3)
        self.assertEqual(
            eqs[2].code(), 'dot(Y.y) = -Y.c * Y.y + Y.d * X.x * Y.y')
        # Nested variables' equations are written with their local names
        eqs = order['X']
        self.assertEqual(len(eqs), 5)
        self.assertEqual(eqs[0].code(), 'X.a = 3')
        self.assertEqual(eqs[1].code(), 'b1 = 1')
        self.assertEqual(eqs[2].code(), 'b2 = X.a - b1 - 1')
        self.assertEqual(eqs[3].code(), 'X.b = b1 + b2')

        # Test model export and cloning
        code1 = m.code()
        code2 = m.clone().code()
        self.assertEqual(code1, code2)
Exemplo n.º 10
0
    def test_remove_variable(self):
        # Test the removal of a variable: simple removal, blocking of removal
        # while other variables still refer to it (by name or by derivative),
        # recursive removal of nested variables, and cleanup of the model's
        # binding and label lists.

        # Create a model
        m = myokit.Model('LotkaVolterra')

        # Add a component to hold the test variables
        X = m.add_component('X')

        # Simplest case; removing the same variable twice raises
        a = X.add_variable('a')
        self.assertEqual(X.count_variables(), 1)
        X.remove_variable(a)
        self.assertEqual(X.count_variables(), 0)
        self.assertRaises(Exception, X.remove_variable, a)

        # Test re-adding
        a = X.add_variable('a')
        a.set_rhs(myokit.Number(5))
        self.assertEqual(X.count_variables(), 1)

        # Test deleting dependent variables
        b = X.add_variable('b')
        self.assertEqual(X.count_variables(), 2)
        b.set_rhs(myokit.Plus(myokit.Number(3), myokit.Name(a)))

        # Test blocking of removal: b refers to a, so a cannot be removed
        self.assertRaises(myokit.IntegrityError, X.remove_variable, a)
        self.assertEqual(X.count_variables(), 2)

        # Test removal in the right order
        X.remove_variable(b)
        self.assertEqual(X.count_variables(), 1)
        X.remove_variable(a)
        self.assertEqual(X.count_variables(), 0)

        # Test reference to current state variable values
        a = X.add_variable('a')
        a.set_rhs(myokit.Number(5))
        a.promote()
        b = X.add_variable('b')
        b.set_rhs(myokit.Plus(myokit.Number(3), myokit.Name(a)))
        self.assertRaises(myokit.IntegrityError, X.remove_variable, a)
        X.remove_variable(b)
        X.remove_variable(a)
        self.assertEqual(X.count_variables(), 0)

        # Test reference to current state variable values with "self"-ref:
        # a state whose rhs refers only to itself can still be removed
        a = X.add_variable('a')
        a.promote()
        a.set_rhs(myokit.Name(a))
        X.remove_variable(a)

        # Test it doesn't interfere with normal workings
        a = X.add_variable('a')
        a.promote()
        a.set_rhs(myokit.Name(a))
        b = X.add_variable('b')
        b.set_rhs(myokit.Name(a))
        self.assertRaises(myokit.IntegrityError, X.remove_variable, a)
        X.remove_variable(b)
        X.remove_variable(a)

        # Test reference to dot
        a = X.add_variable('a')
        a.set_rhs(myokit.Number(5))
        a.promote()
        b = X.add_variable('b')
        b.set_rhs(myokit.Derivative(myokit.Name(a)))
        self.assertRaises(myokit.IntegrityError, X.remove_variable, a)
        X.remove_variable(b)
        X.remove_variable(a)

        # Test if orphaned: b was removed above and has no parent any more
        self.assertIsNone(b.parent())

        # Test deleting variable with nested variables: a non-recursive
        # removal is refused and leaves both a and its child in place
        a = X.add_variable('a')
        b = a.add_variable('b')
        b.set_rhs(myokit.Plus(myokit.Number(2), myokit.Number(2)))
        a.set_rhs(myokit.Multiply(myokit.Number(3), myokit.Name(b)))
        self.assertRaises(myokit.IntegrityError, X.remove_variable, a)
        self.assertEqual(a.count_variables(), 1)
        self.assertEqual(X.count_variables(), 1)

        # Test recursive deleting: removes a and its nested variables
        X.remove_variable(a, recursive=True)
        self.assertEqual(a.count_variables(), 0)
        self.assertEqual(X.count_variables(), 0)

        # Test deleting variable with nested variables that depend on each
        # other
        a = X.add_variable('a')
        b = a.add_variable('b')
        c = a.add_variable('c')
        d = a.add_variable('d')
        a.set_rhs('b + c - d')
        a.promote(0.1)
        b.set_rhs('2 * a - d')
        c.set_rhs('a + b + d')
        d.set_rhs('3 * a')
        self.assertRaises(myokit.IntegrityError, X.remove_variable, a)
        self.assertEqual(a.count_variables(), 3)
        self.assertEqual(X.count_variables(), 1)
        X.remove_variable(a, recursive=True)
        self.assertEqual(a.count_variables(), 0)
        self.assertEqual(X.count_variables(), 0)

        # Test if removed from model's label and binding lists
        m = myokit.Model()
        c = m.add_component('c')
        x = c.add_variable('x')
        y = c.add_variable('y')
        x.set_rhs(0)
        y.set_rhs(0)
        x.set_binding('time')
        y.set_label('membrane_potential')
        self.assertIs(m.binding('time'), x)
        self.assertIs(m.label('membrane_potential'), y)
        # Removing x clears only its binding; y's label is unaffected
        c.remove_variable(x)
        self.assertIs(m.binding('time'), None)
        self.assertIs(m.label('membrane_potential'), y)
        c.remove_variable(y)
        self.assertIs(m.binding('time'), None)
        self.assertIs(m.label('membrane_potential'), None)
Exemplo n.º 11
0
    def test_reader_writer(self):
        # Test round-tripping expressions between Myokit and SymPy, using the
        # SymPyExpressionWriter (Myokit -> SymPy) and SymPyExpressionReader
        # (SymPy -> Myokit).
        try:
            import sympy as sp
        except ImportError:
            # Report the test as skipped instead of letting it pass silently
            self.skipTest('Sympy not found, skipping test.')

        # Create writer and reader
        w = mypy.SymPyExpressionWriter()
        r = mypy.SymPyExpressionReader(self._model)

        # Name
        a = self._a
        ca = sp.Symbol('c.a')
        self.assertEqual(w.ex(a), ca)
        self.assertEqual(r.ex(ca), a)

        # Number with unit
        b = myokit.Number('12', 'pF')
        cb = sp.Float(12)
        self.assertEqual(w.ex(b), cb)
        # Note: Units are lost in the SymPy export, so reading cb back gives
        # a Number without a unit (checked below) and no round-trip check is
        # possible here.

        # Number without unit
        b = myokit.Number('12')
        cb = sp.Float(12)
        self.assertEqual(w.ex(b), cb)
        self.assertEqual(r.ex(cb), b)

        # Prefix plus
        x = myokit.PrefixPlus(b)
        self.assertEqual(w.ex(x), cb)
        # Note: Sympy doesn't seem to have a prefix plus
        self.assertEqual(r.ex(cb), b)

        # Prefix minus
        # Note: SymPy treats -x as Mul(NegativeOne, x)
        # But for numbers just returns a number with a negative value
        x = myokit.PrefixMinus(b)
        self.assertEqual(w.ex(x), -cb)
        self.assertEqual(float(r.ex(-cb)), float(x))

        # Plus
        x = myokit.Plus(a, b)
        self.assertEqual(w.ex(x), ca + cb)
        # Note: SymPy likes to re-order the operands, so compare numerically
        self.assertEqual(float(r.ex(ca + cb)), float(x))

        # Minus
        x = myokit.Minus(a, b)
        self.assertEqual(w.ex(x), ca - cb)
        self.assertEqual(float(r.ex(ca - cb)), float(x))

        # Multiply
        x = myokit.Multiply(a, b)
        self.assertEqual(w.ex(x), ca * cb)
        self.assertEqual(float(r.ex(ca * cb)), float(x))

        # Divide
        x = myokit.Divide(a, b)
        self.assertEqual(w.ex(x), ca / cb)
        self.assertEqual(float(r.ex(ca / cb)), float(x))

        # Quotient
        x = myokit.Quotient(a, b)
        self.assertEqual(w.ex(x), ca // cb)
        self.assertEqual(float(r.ex(ca // cb)), float(x))

        # Remainder
        x = myokit.Remainder(a, b)
        self.assertEqual(w.ex(x), ca % cb)
        self.assertEqual(float(r.ex(ca % cb)), float(x))

        # Power
        x = myokit.Power(a, b)
        self.assertEqual(w.ex(x), ca**cb)
        self.assertEqual(float(r.ex(ca**cb)), float(x))

        # Sqrt
        x = myokit.Sqrt(a)
        cx = sp.sqrt(ca)
        self.assertEqual(w.ex(x), cx)
        # Note: SymPy converts sqrt to power
        self.assertEqual(float(r.ex(cx)), float(x))

        # Exp
        x = myokit.Exp(a)
        cx = sp.exp(ca)
        self.assertEqual(w.ex(x), cx)
        self.assertEqual(r.ex(cx), x)

        # Log(a)
        x = myokit.Log(a)
        cx = sp.log(ca)
        self.assertEqual(w.ex(x), cx)
        self.assertEqual(r.ex(cx), x)

        # Log(a, b)
        x = myokit.Log(a, b)
        cx = sp.log(ca, cb)
        self.assertEqual(w.ex(x), cx)
        self.assertEqual(float(r.ex(cx)), float(x))

        # Log10
        x = myokit.Log10(b)
        cx = sp.log(cb, 10)
        self.assertEqual(w.ex(x), cx)
        self.assertAlmostEqual(float(r.ex(cx)), float(x))

        # Sin
        x = myokit.Sin(a)
        cx = sp.sin(ca)
        self.assertEqual(w.ex(x), cx)
        self.assertEqual(r.ex(cx), x)

        # Cos
        x = myokit.Cos(a)
        cx = sp.cos(ca)
        self.assertEqual(w.ex(x), cx)
        self.assertEqual(r.ex(cx), x)

        # Tan
        x = myokit.Tan(a)
        cx = sp.tan(ca)
        self.assertEqual(w.ex(x), cx)
        self.assertEqual(r.ex(cx), x)

        # ASin
        x = myokit.ASin(a)
        cx = sp.asin(ca)
        self.assertEqual(w.ex(x), cx)
        self.assertEqual(r.ex(cx), x)

        # ACos
        x = myokit.ACos(a)
        cx = sp.acos(ca)
        self.assertEqual(w.ex(x), cx)
        self.assertEqual(r.ex(cx), x)

        # ATan
        x = myokit.ATan(a)
        cx = sp.atan(ca)
        self.assertEqual(w.ex(x), cx)
        self.assertEqual(r.ex(cx), x)

        # Floor
        x = myokit.Floor(a)
        cx = sp.floor(ca)
        self.assertEqual(w.ex(x), cx)
        self.assertEqual(r.ex(cx), x)

        # Ceil
        x = myokit.Ceil(a)
        cx = sp.ceiling(ca)
        self.assertEqual(w.ex(x), cx)
        self.assertEqual(r.ex(cx), x)

        # Abs
        x = myokit.Abs(a)
        cx = sp.Abs(ca)
        self.assertEqual(w.ex(x), cx)
        self.assertEqual(r.ex(cx), x)

        # Equal
        x = myokit.Equal(a, b)
        cx = sp.Eq(ca, cb)
        self.assertEqual(w.ex(x), cx)
        self.assertEqual(r.ex(cx), x)

        # NotEqual
        x = myokit.NotEqual(a, b)
        cx = sp.Ne(ca, cb)
        self.assertEqual(w.ex(x), cx)
        self.assertEqual(r.ex(cx), x)

        # More
        x = myokit.More(a, b)
        cx = sp.Gt(ca, cb)
        self.assertEqual(w.ex(x), cx)
        self.assertEqual(r.ex(cx), x)

        # Less
        x = myokit.Less(a, b)
        cx = sp.Lt(ca, cb)
        self.assertEqual(w.ex(x), cx)
        self.assertEqual(r.ex(cx), x)

        # MoreEqual
        x = myokit.MoreEqual(a, b)
        cx = sp.Ge(ca, cb)
        self.assertEqual(w.ex(x), cx)
        self.assertEqual(r.ex(cx), x)

        # LessEqual
        x = myokit.LessEqual(a, b)
        cx = sp.Le(ca, cb)
        self.assertEqual(w.ex(x), cx)
        self.assertEqual(r.ex(cx), x)

        # Not
        x = myokit.Not(a)
        cx = sp.Not(ca)
        self.assertEqual(w.ex(x), cx)
        self.assertEqual(r.ex(cx), x)

        # And
        cond1 = myokit.More(a, b)
        cond2 = myokit.Less(a, b)
        c1 = sp.Gt(ca, cb)
        c2 = sp.Lt(ca, cb)

        x = myokit.And(cond1, cond2)
        cx = sp.And(c1, c2)
        self.assertEqual(w.ex(x), cx)
        self.assertEqual(r.ex(cx), x)

        # Or
        x = myokit.Or(cond1, cond2)
        cx = sp.Or(c1, c2)
        self.assertEqual(w.ex(x), cx)
        self.assertEqual(r.ex(cx), x)

        # If
        # Note: sympy only does piecewise, not if
        x = myokit.If(cond1, a, b)
        cx = sp.Piecewise((ca, c1), (cb, True))
        self.assertEqual(w.ex(x), cx)
        self.assertEqual(r.ex(cx), x.piecewise())

        # Piecewise
        c = myokit.Number(1)
        cc = sp.Float(1)
        x = myokit.Piecewise(cond1, a, cond2, b, c)
        cx = sp.Piecewise((ca, c1), (cb, c2), (cc, True))
        self.assertEqual(w.ex(x), cx)
        self.assertEqual(r.ex(cx), x)

        # Myokit piecewise's (like CellML's) always have a final True
        # condition (i.e. an 'else'). SymPy doesn't require this, so test if
        # we can import this --> It will add an "else 0"
        x = myokit.Piecewise(cond1, a, myokit.Number(0))
        cx = sp.Piecewise((ca, c1))
        self.assertEqual(r.ex(cx), x)

        # SymPy function without Myokit equivalent --> Should raise exception
        cu = sp.principal_branch(cx, cc)
        self.assertRaisesRegex(ValueError, 'Unsupported type', r.ex, cu)

        # Derivative
        # NOTE(review): the reader below is built on self._model rather than
        # on the promoted clone ``m``, so the promotion is only applied to the
        # clone; confirm whether the reader was meant to use ``m``.
        m = self._model.clone()
        avar = m.get('c.a')
        r = mypy.SymPyExpressionReader(self._model)
        avar.promote(4)
        x = myokit.Derivative(self._a)
        cx = sp.symbols('dot(c.a)')
        self.assertEqual(w.ex(x), cx)
        self.assertEqual(r.ex(cx), x)

        # Equation
        e = myokit.Equation(a, b)
        ce = sp.Eq(ca, cb)
        self.assertEqual(w.eq(e), ce)
        # There's no backwards equivalent for this!
        # The ereader can handle it, but it becomes an Equal expression.

        # Test sympy division: 1 / a is read back as 1 * a**-1
        del (m, avar, x, cx, e, ce)
        a = self._model.get('c.a')
        b = self._model.get('c').add_variable('bbb')
        b.set_rhs('1 / a')
        ce = w.ex(b.rhs())
        e = r.ex(ce)
        self.assertEqual(
            e,
            myokit.Multiply(myokit.Number(1),
                            myokit.Power(myokit.Name(a), myokit.Number(-1))))

        # Test sympy negative numbers
        a = self._model.get('c.a')
        e1 = myokit.PrefixMinus(myokit.Name(a))
        ce = w.ex(e1)
        e2 = r.ex(ce)
        self.assertEqual(e1, e2)
Exemplo n.º 12
0
    def test_all(self):
        # Check the Python expression writer's output for every expression
        # type, then check the ``ewriter('python')`` lookup, a custom lhs
        # function, and the error raised for non-expression input.
        writer = myokit.formats.python.PythonExpressionWriter()

        model = myokit.Model()
        component = model.add_component('c')
        avar = component.add_variable('a')

        # Shared operands: a name, a number with a unit, and two conditions
        ref = myokit.Name(avar)
        num = myokit.Number(12, 'pF')
        gt = myokit.parse_expression('5 > 3')
        lt = myokit.parse_expression('2 < 1')

        # (expression, expected Python code) pairs, checked in order
        cases = [
            # Names, derivatives, partial derivatives, initial values
            (ref, 'c.a'),
            (myokit.Derivative(ref), 'dot(c.a)'),
            (myokit.PartialDerivative(ref, ref), 'diff(c.a, c.a)'),
            (myokit.InitialValue(ref), 'init(c.a)'),
            # Numbers, without and with a unit
            (myokit.Number(3), '3.0'),
            (num, '12.0'),
            # Prefix operators
            (myokit.PrefixPlus(num), '12.0'),
            (myokit.PrefixMinus(num), '(-12.0)'),
            # Binary arithmetic
            (myokit.Plus(ref, num), 'c.a + 12.0'),
            (myokit.Minus(ref, num), 'c.a - 12.0'),
            (myokit.Multiply(ref, num), 'c.a * 12.0'),
            (myokit.Divide(ref, num), 'c.a / 12.0'),
            (myokit.Quotient(ref, num), 'c.a // 12.0'),
            (myokit.Remainder(ref, num), 'c.a % 12.0'),
            (myokit.Power(ref, num), 'c.a ** 12.0'),
            # Functions
            (myokit.Sqrt(num), 'math.sqrt(12.0)'),
            (myokit.Exp(ref), 'math.exp(c.a)'),
            (myokit.Log(num), 'math.log(12.0)'),
            (myokit.Log(ref, num), 'math.log(c.a, 12.0)'),
            (myokit.Log10(num), 'math.log10(12.0)'),
            (myokit.Sin(num), 'math.sin(12.0)'),
            (myokit.Cos(num), 'math.cos(12.0)'),
            (myokit.Tan(num), 'math.tan(12.0)'),
            (myokit.ASin(num), 'math.asin(12.0)'),
            (myokit.ACos(num), 'math.acos(12.0)'),
            (myokit.ATan(num), 'math.atan(12.0)'),
            (myokit.Floor(num), 'math.floor(12.0)'),
            (myokit.Ceil(num), 'math.ceil(12.0)'),
            (myokit.Abs(num), 'abs(12.0)'),
            # Comparisons
            (myokit.Equal(ref, num), '(c.a == 12.0)'),
            (myokit.NotEqual(ref, num), '(c.a != 12.0)'),
            (myokit.More(ref, num), '(c.a > 12.0)'),
            (myokit.Less(ref, num), '(c.a < 12.0)'),
            (myokit.MoreEqual(ref, num), '(c.a >= 12.0)'),
            (myokit.LessEqual(ref, num), '(c.a <= 12.0)'),
            # Logical operators
            (myokit.Not(gt), 'not ((5.0 > 3.0))'),
            (myokit.And(gt, lt), '((5.0 > 3.0) and (2.0 < 1.0))'),
            (myokit.Or(gt, lt), '((5.0 > 3.0) or (2.0 < 1.0))'),
            # Conditionals
            (myokit.If(gt, ref, num), '(c.a if (5.0 > 3.0) else 12.0)'),
            (myokit.Piecewise(gt, ref, lt, num, myokit.Number(1)),
             '(c.a if (5.0 > 3.0) else (12.0 if (2.0 < 1.0) else 1.0))'),
        ]
        for expression, expected in cases:
            self.assertEqual(writer.ex(expression), expected)

        # The ewriter lookup returns the same writer class
        writer = myokit.formats.ewriter('python')
        self.assertIsInstance(
            writer, myokit.formats.python.PythonExpressionWriter)

        # A custom lhs function changes how names are written
        writer.set_lhs_function(lambda x: 'sheep')
        self.assertEqual(writer.ex(ref), 'sheep')

        # Input that is not a Myokit expression is rejected
        self.assertRaisesRegex(
            ValueError, 'Unknown expression type', writer.ex, 7)
Exemplo n.º 13
0
    def test_derivatives(self):
        # Test parsing of derivatives

        # Basic derivative
        x = (
            '<apply>'
            '  <diff/>'
            '  <bvar>'
            '    <ci>time</ci>'
            '  </bvar>'
            '  <ci>V</ci>'
            '</apply>'
        )
        e = myokit.Derivative(myokit.Name('V'))
        self.assertEqual(self.p(x), e)

        # Derivative with degree element
        x = (
            '<apply>'
            '  <diff/>'
            '  <bvar>'
            '    <ci>time</ci>'
            '    <degree><cn>1.0</cn></degree>'
            '  </bvar>'
            '  <ci>V</ci>'
            '</apply>'
        )
        e = myokit.Derivative(myokit.Name('V'))
        self.assertEqual(self.p(x), e)

        # Derivative with degree element other than 1
        self.assertRaisesRegex(
            mathml.MathMLError, 'degree one',
            self.p,
            '<apply>'
            '  <diff/>'
            '  <bvar>'
            '    <ci>time</ci>'
            '    <degree><cn>2</cn></degree>'
            '  </bvar>'
            '  <ci>V</ci>'
            '</apply>'
        )

        # Derivative with degree element but no cn
        self.assertRaisesRegex(
            mathml.MathMLError, '<degree> element must contain a <cn>',
            self.p,
            '<apply>'
            '  <diff/>'
            '  <bvar>'
            '    <ci>time</ci>'
            '    <degree/>'
            '  </bvar>'
            '  <ci>V</ci>'
            '</apply>'
        )

        # Derivative of an expression
        self.assertRaisesRegex(
            mathml.MathMLError, '<diff> element must contain a <ci>', self.p,
            '<apply>'
            '  <diff/>'
            '  <bvar>'
            '    <ci>time</ci>'
            '  </bvar>'
            '  <apply><plus/><cn>1.0</cn><ci>V</ci></apply>'
            '</apply>'
        )

        # Derivative without a bvar
        self.assertRaisesRegex(
            mathml.MathMLError, '<diff> element must contain a <bvar>', self.p,
            '<apply>'
            '  <diff/>'
            '  <ci>x</ci>'
            '</apply>'
        )

        # Derivative without ci in its bvar
        self.assertRaisesRegex(
            mathml.MathMLError, '<bvar> element must contain a <ci>', self.p,
            '<apply>'
            '  <diff/>'
            '  <bvar/>'
            '  <ci>x</ci>'
            '</apply>'
        )