Example #1
0
 def unused_and_cycles(self):
     """
     Checks that validation reports unused variables as warnings, and that
     cyclical dependencies (also those passing through a state variable's
     derivative) are rejected with a CyclicalDependencyError.
     """
     def check_unused(count):
         # Validate the model, then check that exactly ``count`` warnings
         # were produced, each one about an unused variable.
         m.validate()
         found = m.warnings()
         self.assertEqual(len(found), count)
         for warning in found:
             self.assertEqual(type(warning), UnusedVariableError)

     m = Model('LotkaVolterra')

     # Component holding the bound time variable
     c0 = m.add_component('c0')
     engine_time = c0.add_variable('time')
     engine_time.set_rhs(Number(0))
     engine_time.set_binding('time')

     # Component c1 with state a and variable b, plus an empty component c2
     c1 = m.add_component('c1')
     m.add_component('c2')
     a = c1.add_variable('a')
     b = c1.add_variable('b')
     a.promote(1.0)
     a.set_rhs(Multiply(Name(a), Number(0.5)))
     b.set_rhs(Multiply(Name(a), Number(1.0)))

     # Nothing refers to b: expect one warning
     check_unused(1)

     # c refers to b, but nothing refers to c: expect two warnings
     c = c1.add_variable('c')
     c.set_rhs(Name(b))
     check_unused(2)

     # Cycle of length one: b = b
     b.set_rhs(Name(b))
     self.assertRaises(CyclicalDependencyError, m.validate)

     # Longer cycle: b = 10 * c, c = b
     b.set_rhs(Multiply(Number(10), Name(c)))
     self.assertRaises(CyclicalDependencyError, m.validate)

     # Restore an acyclic definition
     b.set_rhs(Multiply(Name(a), Number(1.0)))
     m.validate()

     # dot(a) = b with b depending on a is allowed...
     a.set_rhs(Name(b))
     m.validate()
     # ...but b = a * b is still a cycle
     b.set_rhs(Multiply(Name(a), Name(b)))
     self.assertRaises(CyclicalDependencyError, m.validate)

     # dot(a) = b, b = a * c, c = a * 3: valid, and no warnings expected
     b.set_rhs(Multiply(Name(a), Name(c)))
     c.set_rhs(Multiply(Name(a), Number(3)))
     check_unused(0)

     # c = a * b closes the loop b -> c -> b
     c.set_rhs(Multiply(Name(a), Name(b)))
     self.assertRaises(CyclicalDependencyError, m.validate)
Example #2
0
 def move_variable(self):
     """
     Tests the method to move component variables to another component.

     Builds a Lotka-Volterra model, moves the ``time`` variable from the
     ``engine`` component into ``Z``, and checks that the model still
     validates and that the variable has actually changed components.
     """
     # Create a model
     m = Model('LotkaVolterra')
     X = m.add_component('X')
     a = X.add_variable('a')
     a.set_rhs(3)
     b = X.add_variable('b')
     b1 = b.add_variable('b1')
     b2 = b.add_variable('b2')
     b1.set_rhs(1)
     b2.set_rhs(Minus(Minus(Name(a), Name(b1)), Number(1)))
     b.set_rhs(Plus(Name(b1), Name(b2)))
     x = X.add_variable('x')
     x.promote()
     Y = m.add_component('Y')
     c = Y.add_variable('c')
     c.set_rhs(Minus(Name(a), Number(1)))
     d = Y.add_variable('d')
     d.set_rhs(2)
     y = Y.add_variable('y')
     y.promote()
     x.set_rhs(
         Minus(Multiply(Name(a), Name(x)),
               Multiply(Multiply(Name(b), Name(x)), Name(y))))
     x.set_state_value(10)
     y.set_rhs(
         Plus(Multiply(PrefixMinus(Name(c)), Name(y)),
              Multiply(Multiply(Name(d), Name(x)), Name(y))))
     y.set_state_value(5)
     Z = m.add_component('Z')
     t = Z.add_variable('total')
     t.set_rhs(Plus(Name(x), Name(y)))
     E = m.add_component('engine')
     time = E.add_variable('time')
     time.set_rhs(0)
     time.set_binding('time')
     # Move time variable into Z
     m.validate()  # If not valid, this will raise an exception
     E.move_variable(time, Z)
     m.validate()
     # The variable should now be found in Z, and no longer in the engine
     self.assertIn('time', Z)
     self.assertNotIn('time', E)
Example #3
0
    def _parse(self, path, model):
        """
        Parses a ChannelML channel and adds it to the given model.

        The ``model`` must contain a variable labelled ``membrane_potential``.
        A new component is added for the channel (named after the channel,
        with ``_2``, ``_3``, ... appended if that name is already taken),
        containing:

        - ``v``, an alias to the membrane potential;
        - ``E``, the reversal potential (``default_erev``, or 0 if not set);
        - ``gmax``, the maximum conductance (``default_gmax``, or 1 if not
          set);
        - one state variable per gate, defined either by a pair of
          closed-to-open / open-to-closed transitions, or by a steady state
          and a time course;
        - ``I = gmax * gates * (V - E)``, where each gate is raised to the
          power given by its ``instances`` attribute (1 if not set).

        Expressions that cannot be parsed do not stop the import: a warning
        is issued and the unparsed string is stored in the variable's
        ``meta['expression']``.

        Returns the new :class:`myokit.Component`.

        Raises a ``ChannelMLError`` if the model has no membrane potential
        variable, if the file is missing required elements, or if a gate has
        an ``instances`` attribute that is not an integer.
        """
        # Check model: get membrane potential variable
        vvar = model.label('membrane_potential')
        if vvar is None:
            raise ChannelMLError(
                'No variable labelled "membrane_potential" was found. This is'
                ' required when adding ChannelML channels to existing models.')

        # Parse XML
        path = os.path.abspath(os.path.expanduser(path))
        dom = xml.dom.minidom.parse(path)

        # Get channelml tag
        root = dom.getElementsByTagName('channelml')
        try:
            root = root[0]
        except IndexError:
            raise ChannelMLError(
                'Unknown root element in xml document. Expecting a tag of type'
                ' <channelml>.')

        # Extract meta data
        meta = self._rip_meta(root)

        # Get channel_type tag
        channel = root.getElementsByTagName('channel_type')
        try:
            channel = channel[0]
        except IndexError:
            raise ChannelMLError(
                'No <channel_type> element found inside <channelml> element.'
                ' Import of <synapse_type> and <ion_concentration> is not'
                ' supported.')

        # Add channel component, with a unique name
        name = self._sanitise_name(channel.getAttribute('name'))
        name_root = name
        i = 2
        while name in model:
            name = name_root + '_' + str(i)
            i += 1
        component = model.add_component(name)

        # Add alias to membrane potential
        component.add_alias('v', vvar)

        # Add meta-data
        component.meta['desc'] = meta

        # Find current-voltage relation
        cvr = channel.getElementsByTagName('current_voltage_relation')
        if len(cvr) < 1:
            raise ChannelMLError(
                'Channel model must contain a current voltage relation.')
        elif len(cvr) > 1:
            warnings.warn(
                'Multiple current voltage relations found, ignoring all but'
                ' first.')
        cvr = cvr[0]

        # Check for q10
        q10 = cvr.getElementsByTagName('q10_settings')
        if q10:
            component.meta['experimental_temperature'] = str(
                q10[0].getAttribute('experimental_temp'))

        # Add reversal potential
        E = 0
        if cvr.hasAttribute('default_erev'):
            E = float(cvr.getAttribute('default_erev'))
        evar = component.add_variable('E')
        evar.meta['desc'] = 'Reversal potential'
        evar.set_rhs(E)

        # Get maximum conductance
        gmax = 1.0
        if cvr.hasAttribute('default_gmax'):
            gmax = float(cvr.getAttribute('default_gmax'))
        gmaxvar = component.add_variable('gmax')
        gmaxvar.set_rhs(gmax)
        gmaxvar.meta['desc'] = 'Maximum conductance'

        def first_state_id(gate, gname, tag):
            # Returns the ``id`` of the first <tag> element inside ``gate``,
            # raising a ChannelMLError (rather than an IndexError) if there
            # is no such element.
            elements = gate.getElementsByTagName(tag)
            if not elements:
                raise ChannelMLError(
                    'No <' + tag + '> element found for gate <' + gname +
                    '>.')
            return elements[0].getAttribute('id')

        def find_transition(transitions, source, target):
            # Returns the first transition from state ``source`` to state
            # ``target``, or None if there is no such transition.
            for t in transitions:
                if (t.getAttribute('from') == source and
                        t.getAttribute('to') == target):
                    return t
            return None

        def add_expression_variable(gvar, gname, element, description):
            # Adds a child variable to ``gvar``, named after ``element``'s
            # ``name`` attribute and defined by its ``expr`` attribute. If
            # the expression can't be parsed, a warning is issued and the
            # raw string is kept in the variable's meta data instead.
            var = gvar.add_variable(
                self._sanitise_name(element.getAttribute('name')))
            expr = str(element.getAttribute('expr'))
            try:
                var.set_rhs(self._parse_expression(expr, var))
            except myokit.ParseError as e:
                warnings.warn(
                    'Error parsing expression for ' + description +
                    ' in gate <' + gname + '>: ' +
                    myokit.format_parse_error(e))
                var.meta['expression'] = expr
            return var

        # Add gates
        gvars = []
        for gate in cvr.getElementsByTagName('gate'):
            gname = self._sanitise_name(gate.getAttribute('name'))
            gvar = component.add_variable(gname)
            gvar.promote(0)
            cstate = first_state_id(gate, gname, 'closed_state')
            ostate = first_state_id(gate, gname, 'open_state')

            # Transitions
            trans = gate.getElementsByTagName('transition')
            if len(trans) > 0:
                # Use "transitions" definition
                if len(trans) != 2:
                    raise ChannelMLError(
                        'Expecting exactly 2 transitions for gate <' + gname +
                        '>.')

                # Find both transitions before adding any variables
                tco = find_transition(trans, cstate, ostate)
                if tco is None:
                    raise ChannelMLError(
                        'Unable to find closed-to-open transition for gate <' +
                        gname + '>')
                toc = find_transition(trans, ostate, cstate)
                if toc is None:
                    raise ChannelMLError(
                        'Unable to find open-to-closed transition for gate <' +
                        gname + '>')

                # Add transition rate variables
                tcovar = add_expression_variable(
                    gvar, gname, tco, 'closed-to-open transition')
                tocvar = add_expression_variable(
                    gvar, gname, toc, 'open-to-closed transition')

                # Write equation for gate: dot(g) = a * (1 - g) - b * g
                gvar.set_rhs(
                    Minus(Multiply(Name(tcovar), Minus(Number(1), Name(gvar))),
                          Multiply(Name(tocvar), Name(gvar))))

            else:
                # Use "steady-state & time_course" definition
                ss = gate.getElementsByTagName('steady_state')
                tc = gate.getElementsByTagName('time_course')
                if len(ss) < 1 or len(tc) < 1:
                    raise ChannelMLError(
                        'Unable to find transitions or steady state and'
                        ' time course for gate <' + gname + '>.')

                # Add steady-state and time course variables
                ssvar = add_expression_variable(
                    gvar, gname, ss[0], 'steady state')
                tcvar = add_expression_variable(
                    gvar, gname, tc[0], 'time course')

                # Write expression for gate: dot(g) = (inf - g) / tau
                gvar.set_rhs(
                    Divide(Minus(Name(ssvar), Name(gvar)), Name(tcvar)))

            # Number of gate instances, used as the gate's exponent in the
            # current equation. Defaults to 1 when the attribute is absent
            # (getAttribute would return an empty string).
            power = 1
            if gate.hasAttribute('instances'):
                try:
                    power = int(gate.getAttribute('instances'))
                except ValueError:
                    raise ChannelMLError(
                        'Invalid number of instances for gate <' + gname +
                        '>: expecting an integer.')
            if power > 1:
                gvars.append(Power(Name(gvar), Number(power)))
            else:
                gvars.append(Name(gvar))

        if not gvars:
            raise ChannelMLError(
                'Current voltage relation requires at least one gate.')

        # Add current variable: gmax * gates * (V - E), with the gates
        # multiplied in last-to-first order
        ivar = component.add_variable('I')
        ivar.meta['desc'] = 'Current'
        expr = Name(gmaxvar)
        for g in reversed(gvars):
            expr = Multiply(expr, g)
        expr = Multiply(expr, Minus(Name(vvar), Name(evar)))
        ivar.set_rhs(expr)

        # Done, return component
        return component
Example #4
0
    def model_creation(self):
        """
        Tests building a Lotka-Volterra model step by step.

        Covers adding components and (nested) variables, promoting and
        demoting states, the variable and equation iterators, shallow
        dependency maps, validation, the solvable order, and code export
        and cloning.
        """
        # Create a model
        m = Model('LotkaVolterra')

        # Add the first component
        X = m.add_component('X')
        self.assertEqual(X.qname(), 'X')
        self.assertEqual(X.parent(), m)
        self.assertIsInstance(X, Component)
        self.assertIn(X.qname(), m)
        self.assertEqual(len(m), 1)

        # Add variable a
        a = X.add_variable('a')
        self.assertEqual(a, a)
        self.assertIsInstance(a, Variable)
        self.assertEqual(len(X), 1)
        self.assertIn(a.name(), X)
        a.set_rhs(3)
        self.assertFalse(a.is_state())
        self.assertFalse(a.is_intermediary())
        self.assertTrue(a.is_constant())
        self.assertEqual(a.lhs(), Name(a))
        self.assertEqual(a.rhs(), Number(3))
        self.assertEqual(a.rhs().eval(), 3)
        self.assertEqual(a.code(), 'a = 3\n')
        self.assertEqual(a.eq().code(), 'X.a = 3')
        self.assertEqual(a.lhs().code(), 'X.a')
        self.assertEqual(a.rhs().code(), '3')
        self.assertEqual(a.eq(), Equation(Name(a), Number(3)))

        # Check lhs: separately created Names for the same variable are equal
        a_name1 = myokit.Name(a)
        a_name2 = myokit.Name(a)
        self.assertEqual(a_name1, a_name1)
        self.assertEqual(a_name2, a_name2)
        self.assertEqual(a_name1, a_name2)
        self.assertEqual(a_name2, a_name1)

        # Add variable b with two temporary variables
        b = X.add_variable('b')
        self.assertIsInstance(b, Variable)
        self.assertEqual(len(X), 2)
        self.assertIn(b.name(), X)
        b1 = b.add_variable('b1')
        self.assertEqual(len(b), 1)
        self.assertIn(b1.name(), b)
        self.assertIsInstance(b1, Variable)
        b2 = b.add_variable('b2')
        self.assertEqual(len(b), 2)
        self.assertIn(b2.name(), b)
        self.assertIsInstance(b2, Variable)
        b1.set_rhs(1)
        b2.set_rhs(Minus(Minus(Name(a), Name(b1)), Number(1)))
        b.set_rhs(Plus(Name(b1), Name(b2)))
        self.assertEqual(b.rhs().eval(), 2)
        self.assertFalse(b.is_state())
        self.assertFalse(b.is_intermediary())
        self.assertTrue(b.is_constant())
        self.assertEqual(b.lhs(), Name(b))

        # Add state variable x
        x = X.add_variable('x')
        x.set_rhs(10)
        x.promote()
        self.assertNotEqual(x, X)
        self.assertIsInstance(x, Variable)
        self.assertEqual(len(X), 3)
        self.assertIn(x.name(), X)
        self.assertTrue(x.is_state())
        self.assertFalse(x.is_intermediary())
        self.assertFalse(x.is_constant())
        self.assertEqual(x.lhs(), Derivative(Name(x)))

        # Test demoting, promoting
        x.demote()
        self.assertFalse(x.is_state())
        self.assertFalse(x.is_intermediary())
        self.assertTrue(x.is_constant())
        self.assertEqual(x.lhs(), Name(x))
        x.promote()
        self.assertTrue(x.is_state())
        self.assertFalse(x.is_intermediary())
        self.assertFalse(x.is_constant())
        self.assertEqual(x.lhs(), Derivative(Name(x)))
        x.demote()
        x.promote()
        x.demote()
        x.promote()
        self.assertTrue(x.is_state())
        self.assertFalse(x.is_intermediary())
        self.assertFalse(x.is_constant())
        self.assertEqual(x.lhs(), Derivative(Name(x)))

        # Add second component, variables
        Y = m.add_component('Y')
        self.assertNotEqual(X, Y)
        self.assertEqual(len(m), 2)
        c = Y.add_variable('c')
        c.set_rhs(Minus(Name(a), Number(1)))
        d = Y.add_variable('d')
        d.set_rhs(2)
        y = Y.add_variable('y')
        y.promote()

        # Set rhs for x and y
        x.set_rhs(
            Minus(Multiply(Name(a), Name(x)),
                  Multiply(Multiply(Name(b), Name(x)), Name(y))))
        x.set_state_value(10)
        self.assertEqual(x.rhs().code(), 'X.a * X.x - X.b * X.x * Y.y')
        y.set_rhs(
            Plus(Multiply(PrefixMinus(Name(c)), Name(y)),
                 Multiply(Multiply(Name(d), Name(x)), Name(y))))
        y.set_state_value(5)
        self.assertEqual(y.rhs().code(), '-Y.c * Y.y + Y.d * X.x * Y.y')

        # Add another component, variables
        Z = m.add_component('Z')
        self.assertNotEqual(X, Z)
        self.assertNotEqual(Y, Z)
        self.assertEqual(len(m), 3)
        t = Z.add_variable('total')
        self.assertEqual(t.name(), 'total')
        self.assertEqual(t.qname(), 'Z.total')
        self.assertEqual(t.qname(X), 'Z.total')
        self.assertEqual(t.qname(Z), 'total')
        t.set_rhs(Plus(Name(x), Name(y)))
        self.assertFalse(t.is_state())
        self.assertFalse(t.is_constant())
        self.assertTrue(t.is_intermediary())
        self.assertEqual(t.rhs().code(), 'X.x + Y.y')
        self.assertEqual(t.rhs().code(X), 'x + Y.y')
        self.assertEqual(t.rhs().code(Y), 'X.x + y')
        self.assertEqual(t.rhs().code(Z), 'X.x + Y.y')

        # Add engine component
        E = m.add_component('engine')
        self.assertNotEqual(X, E)
        self.assertNotEqual(Y, E)
        self.assertNotEqual(Z, E)
        self.assertEqual(len(m), 4)
        time = E.add_variable('time')
        time.set_rhs(0)
        self.assertIsNone(time.binding())
        time.set_binding('time')
        self.assertIsNotNone(time.binding())

        # Check state
        state = list(m.states())
        self.assertEqual(len(state), 2)
        self.assertIn(x, state)
        self.assertIn(y, state)

        # Test variable iterators: checks that ``vrs`` contains exactly the
        # given variables, in any order
        def has_vars(*v):
            for var in v:
                self.assertIn(var, vrs)
            self.assertEqual(len(vrs), len(v))

        vrs = list(m.variables())
        has_vars(a, b, c, d, x, y, t, time)
        vrs = list(m.variables(deep=True))
        has_vars(a, b, c, d, x, y, t, b1, b2, time)
        vrs = list(m.variables(const=True))
        has_vars(a, b, c, d)
        vrs = list(m.variables(const=True, deep=True))
        has_vars(a, b, c, d, b1, b2)
        vrs = list(m.variables(const=False))
        has_vars(x, y, t, time)
        vrs = list(m.variables(const=False, deep=True))
        has_vars(x, y, t, time)
        vrs = list(m.variables(state=True))
        has_vars(x, y)
        vrs = list(m.variables(state=True, deep=True))
        has_vars(x, y)
        vrs = list(m.variables(state=False))
        has_vars(a, b, c, d, t, time)
        vrs = list(m.variables(state=False, deep=True))
        has_vars(a, b, c, d, t, b1, b2, time)
        vrs = list(m.variables(inter=True))
        has_vars(t)
        vrs = list(m.variables(inter=True, deep=True))
        has_vars(t)
        vrs = list(m.variables(inter=False))
        has_vars(a, b, c, d, x, y, time)
        vrs = list(m.variables(inter=False, deep=True))
        has_vars(a, b, c, d, x, y, b1, b2, time)
        vrs = list(m.variables(const=True, state=True))
        has_vars()
        vrs = list(m.variables(const=True, state=False))
        has_vars(a, b, c, d)

        # Test equation iteration
        eqs = list(X.equations(deep=False))
        self.assertEqual(len(eqs), 3)
        eqs = list(X.equations(deep=True))
        self.assertEqual(len(eqs), 5)
        eqs = list(Y.equations(deep=False))
        self.assertEqual(len(eqs), 3)
        eqs = list(Y.equations(deep=True))
        self.assertEqual(len(eqs), 3)
        eqs = list(Z.equations(deep=False))
        self.assertEqual(len(eqs), 1)
        eqs = list(Z.equations(deep=True))
        self.assertEqual(len(eqs), 1)
        eqs = list(E.equations(deep=False))
        self.assertEqual(len(eqs), 1)
        eqs = list(E.equations(deep=True))
        self.assertEqual(len(eqs), 1)
        eqs = list(m.equations(deep=False))
        self.assertEqual(len(eqs), 8)
        eqs = list(m.equations(deep=True))
        self.assertEqual(len(eqs), 10)

        # Test dependency mapping: checks that ``deps[var]`` contains exactly
        # the given dependencies. String arguments are resolved to the named
        # variable's lhs; other arguments (e.g. Name(x), referring to a
        # state's current value) are used as given.
        def has_deps(var, *dps):
            lst = deps[m.get(var).lhs() if isinstance(var, str) else var]
            self.assertEqual(len(lst), len(dps))
            for dep in dps:
                dep = m.get(dep).lhs() if isinstance(dep, str) else dep
                self.assertIn(dep, lst)

        deps = m.map_shallow_dependencies(omit_states=False)
        self.assertEqual(len(deps), 12)
        has_deps('X.a')
        has_deps('X.b', 'X.b.b1', 'X.b.b2')
        has_deps('X.b.b1')
        has_deps('X.b.b2', 'X.a', 'X.b.b1')
        has_deps('X.x', 'X.a', 'X.b', Name(x), Name(y))
        has_deps(Name(x))
        has_deps('Y.c', 'X.a')
        has_deps('Y.d')
        has_deps('Y.y', 'Y.c', 'Y.d', Name(x), Name(y))
        has_deps(Name(y))
        has_deps('Z.total', Name(x), Name(y))
        deps = m.map_shallow_dependencies()
        self.assertEqual(len(deps), 10)
        has_deps('X.a')
        has_deps('X.b', 'X.b.b1', 'X.b.b2')
        has_deps('X.b.b1')
        has_deps('X.b.b2', 'X.a', 'X.b.b1')
        has_deps('X.x', 'X.a', 'X.b')
        has_deps('Y.c', 'X.a')
        has_deps('Y.d')
        has_deps('Y.y', 'Y.c', 'Y.d')
        has_deps('Z.total')
        deps = m.map_shallow_dependencies(collapse=True)
        self.assertEqual(len(deps), 8)
        has_deps('X.a')
        has_deps('X.b', 'X.a')
        has_deps('X.x', 'X.a', 'X.b')
        has_deps('Y.c', 'X.a')
        has_deps('Y.d')
        has_deps('Y.y', 'Y.c', 'Y.d')
        has_deps('Z.total')

        # Validate
        m.validate()

        # Get solvable order: keyed by component name, plus '*remaining*'
        order = m.solvable_order()
        self.assertEqual(len(order), 5)
        self.assertIn('*remaining*', order)
        self.assertIn('X', order)
        self.assertIn('Y', order)
        self.assertIn('Z', order)
        self.assertIn('engine', order)

        # Check that X comes before Y, and that '*remaining*' comes last
        pos = {name: k for k, name in enumerate(order)}
        self.assertLess(pos['X'], pos['Y'])
        self.assertEqual(pos['*remaining*'], 4)

        # Check component equation lists
        eqs = order['*remaining*']
        self.assertEqual(len(eqs), 0)
        eqs = order['Z']
        self.assertEqual(len(eqs), 1)
        self.assertEqual(eqs[0].code(), 'Z.total = X.x + Y.y')
        eqs = order['Y']
        self.assertEqual(len(eqs), 3)
        self.assertEqual(eqs[2].code(),
                         'dot(Y.y) = -Y.c * Y.y + Y.d * X.x * Y.y')
        eqs = order['X']
        self.assertEqual(len(eqs), 5)
        self.assertEqual(eqs[0].code(), 'X.a = 3')
        self.assertEqual(eqs[1].code(), 'b1 = 1')
        self.assertEqual(eqs[2].code(), 'b2 = X.a - b1 - 1')
        self.assertEqual(eqs[3].code(), 'X.b = b1 + b2')
        self.assertEqual(eqs[4].code(),
                         'dot(X.x) = X.a * X.x - X.b * X.x * Y.y')

        # Test model export and cloning
        code1 = m.code()
        code2 = m.clone().code()
        self.assertEqual(code1, code2)
Example #5
0
 def remove_variable(self):
     """
     Tests the removal of a variable.

     Covers removing a plain variable, blocking removal of variables that
     others depend on (through their value, a state's current value, or
     its derivative), state self-references, orphaning of removed
     variables, and recursive removal of variables with nested children.
     After each scenario the component is checked to be empty again.
     """
     # Create a model
     m = Model('LotkaVolterra')
     # Add a variable 'a'
     X = m.add_component('X')
     # Simplest case
     a = X.add_variable('a')
     self.assertEqual(X.count_variables(), 1)
     X.remove_variable(a)
     self.assertEqual(X.count_variables(), 0)
     self.assertRaises(Exception, X.remove_variable, a)
     # Test re-adding
     a = X.add_variable('a')
     a.set_rhs(Number(5))
     self.assertEqual(X.count_variables(), 1)
     # Test deleting dependent variables
     b = X.add_variable('b')
     self.assertEqual(X.count_variables(), 2)
     b.set_rhs(Plus(Number(3), Name(a)))
     # Test blocking of removal
     self.assertRaises(IntegrityError, X.remove_variable, a)
     self.assertEqual(X.count_variables(), 2)
     # Test removal in the right order
     X.remove_variable(b)
     self.assertEqual(X.count_variables(), 1)
     X.remove_variable(a)
     self.assertEqual(X.count_variables(), 0)
     # Test reference to current state variable values
     a = X.add_variable('a')
     a.set_rhs(Number(5))
     a.promote()
     b = X.add_variable('b')
     b.set_rhs(Plus(Number(3), Name(a)))
     self.assertRaises(IntegrityError, X.remove_variable, a)
     X.remove_variable(b)
     X.remove_variable(a)
     self.assertEqual(X.count_variables(), 0)
     # Test reference to current state variable values with "self"-ref
     a = X.add_variable('a')
     a.promote()
     a.set_rhs(Name(a))
     X.remove_variable(a)
     self.assertEqual(X.count_variables(), 0)
     # Test it doesn't interfere with normal workings
     a = X.add_variable('a')
     a.promote()
     a.set_rhs(Name(a))
     b = X.add_variable('b')
     b.set_rhs(Name(a))
     self.assertRaises(IntegrityError, X.remove_variable, a)
     X.remove_variable(b)
     X.remove_variable(a)
     self.assertEqual(X.count_variables(), 0)
     # Test reference to dot
     a = X.add_variable('a')
     a.set_rhs(Number(5))
     a.promote()
     b = X.add_variable('b')
     b.set_rhs(Derivative(Name(a)))
     self.assertRaises(IntegrityError, X.remove_variable, a)
     X.remove_variable(b)
     X.remove_variable(a)
     self.assertEqual(X.count_variables(), 0)
     # Test if orphaned
     self.assertIsNone(b.parent())
     # Test deleting variable with nested variables
     a = X.add_variable('a')
     b = a.add_variable('b')
     b.set_rhs(Plus(Number(2), Number(2)))
     a.set_rhs(Multiply(Number(3), Name(b)))
     self.assertRaises(IntegrityError, X.remove_variable, a)
     self.assertEqual(a.count_variables(), 1)
     self.assertEqual(X.count_variables(), 1)
     # Test recursive deleting
     X.remove_variable(a, recursive=True)
     self.assertEqual(a.count_variables(), 0)
     self.assertEqual(X.count_variables(), 0)
     # Same with dot(a) = a, b = 3 * a
     a = X.add_variable('a')
     a.promote(0.123)
     b = a.add_variable('b')
     b.set_rhs(Multiply(Number(3), Name(a)))
     a.set_rhs(Name(b))
     self.assertRaises(IntegrityError, X.remove_variable, a)
     self.assertRaises(IntegrityError, a.remove_variable, b)
     self.assertRaises(IntegrityError, a.remove_variable, b, True)
     X.remove_variable(a, recursive=True)
     self.assertEqual(a.count_variables(), 0)
     self.assertEqual(X.count_variables(), 0)