示例#1
0
    def test_numpy_writer(self):
        # myokit.numpy_writer() should hand back a NumPy expression writer
        # that writes literals as floats and derives names correctly, both
        # for free-standing expressions and for variables in a model that
        # has not been validated yet.
        import myokit
        import myokit.formats
        import myokit.formats.python

        writer = myokit.numpy_writer()
        self.assertIsInstance(
            writer, myokit.formats.python.NumPyExpressionWriter)

        # Free-standing expression: names are written verbatim
        expr = myokit.parse_expression('5 + 3 * x')
        self.assertEqual(writer.ex(expr), '5.0 + 3.0 * x')

        # Variable inside an unvalidated model (no unames assigned yet):
        # the component-qualified name is expected to come out as `c_x`
        model = myokit.Model()
        var = model.add_component('c').add_variable('x')
        var.set_rhs('5 + x')
        self.assertEqual(writer.ex(var.rhs()), '5.0 + c_x')
示例#2
0
    def __init__(self, model, states, parameters=None, current=None, vm=None):
        """
        Create an analytically solvable model from a subset of the states of
        a :class:`myokit.Model`.

        Arguments:

        ``model``
            A :class:`myokit.Model`. It is converted to a clone (with its
            HH-states in inf-tau form), and only that clone is used.
        ``states``
            A sequence of state variables, given as :class:`myokit.Variable`
            objects or qnames. Each must be a state, may appear only once,
            and must be in "inf-tau form" or "alpha-beta form".
        ``parameters``
            An optional sequence of literal variables (objects or qnames)
            that become inputs to the generated functions. The membrane
            potential must not be listed here.
        ``current``
            An optional non-state variable (object or qname) that must
            depend on the selected states. Its value is returned alongside
            the states by the generated solver function.
        ``vm``
            The membrane potential variable (object or qname). If not given,
            the variable labelled ``membrane_potential`` is used. It must not
            be one of ``states``, ``parameters`` or ``current``.

        Raises ``ValueError`` if ``model`` is not a :class:`myokit.Model`, and
        :class:`HHModelError` for the validation failures above. Lookups of
        ``current`` and ``vm`` are not wrapped, so unknown names propagate
        whatever error ``Model.get`` raises.
        """
        super(HHModel, self).__init__()

        #
        # Check input
        #
        if not isinstance(model, myokit.Model):
            raise ValueError('First argument must be a myokit.Model.')

        # Check membrane potential variable is known, and is a qname
        if vm is None:
            vm = model.label('membrane_potential')
        if vm is None:
            raise HHModelError(
                'A membrane potential must be specified as `vm` or using the'
                ' label `membrane_potential`.')
        if isinstance(vm, myokit.Variable):
            vm = vm.qname()

        # Ensure all HH-states in the model are written in inf-tau form
        # This returns a clone of the original model
        self._model = convert_hh_states_to_inf_tau_form(model, vm)
        # Drop the reference to the caller's model: from here on only the
        # internal clone may be used or modified.
        del model

        # Check and collect state variables
        self._states = []
        for state in states:
            if isinstance(state, myokit.Variable):
                state = state.qname()
            try:
                state = self._model.get(str(state), myokit.Variable)
            except KeyError:
                raise HHModelError('Unknown state: <' + str(state) + '>.')
            if not state.is_state():
                raise HHModelError('Variable <' + state.qname() +
                                   '> is not a state.')
            if state in self._states:
                raise HHModelError('State <' + state.qname() +
                                   '> was added twice.')
            self._states.append(state)
        del states

        # Check and collect parameter variables
        unique = set()
        self._parameters = []
        if parameters is None:
            parameters = []
        for parameter in parameters:
            if isinstance(parameter, myokit.Variable):
                parameter = parameter.qname()
            try:
                parameter = self._model.get(parameter, myokit.Variable)
            except KeyError:
                raise HHModelError('Unknown parameter: <' + str(parameter) +
                                   '>.')
            # Parameters must be plain literals so they can be passed in as
            # function arguments without changing the model's structure
            if not parameter.is_literal():
                raise HHModelError('Unsuitable parameter: <' + str(parameter) +
                                   '>.')
            if parameter in unique:
                raise HHModelError('Parameter listed twice: <' +
                                   str(parameter) + '>.')
            unique.add(parameter)
            self._parameters.append(parameter)
        del unique
        del parameters

        # Check current variable
        if current is not None:
            if isinstance(current, myokit.Variable):
                current = current.qname()
            current = self._model.get(current, myokit.Variable)
            if current.is_state():
                raise HHModelError('Current variable can not be a state.')
        self._current = current
        del current

        # Check membrane potential variable (looked up with the same class
        # filter as the other variables, so a non-variable qname is rejected
        # here rather than failing later on)
        self._membrane_potential = self._model.get(vm, myokit.Variable)
        if self._membrane_potential in self._parameters:
            raise HHModelError(
                'Membrane potential should not be included in the list of'
                ' parameters.')
        if self._membrane_potential in self._states:
            raise HHModelError(
                'The membrane potential should not be included in the list of'
                ' states.')
        if self._membrane_potential == self._current:
            raise HHModelError(
                'The membrane potential should not be the current variable.')
        del vm

        #
        # Demote unnecessary states and remove bindings
        #
        # Get values of all states
        # Note: Do this _before_ changing the model!
        self._default_state = np.array([v.state_value() for v in self._states])

        # Freeze remaining, non-current-model states
        s = self._model.state()  # Get state values before changing anything!
        # Note: list() cast is required so that we iterate over a static list,
        # otherwise we can get issues because the iterator depends on the model
        # (which we're changing).
        for k, state in enumerate(list(self._model.states())):
            if state not in self._states:
                state.demote()
                state.set_rhs(s[k])

        # Unbind everything except time
        # Note: as above, iterate over a static copy, because unbinding a
        # variable removes it from the model's binding register that
        # bindings() iterates over.
        for label, var in list(self._model.bindings()):
            if label != 'time':
                var.set_binding(None)

        # Check if current variable depends on selected states
        # (At this point, anything not dependent on the states is a constant)
        if self._current is not None and self._current.is_constant():
            raise HHModelError(
                'Current must be a function of the selected state variables.')

        # Ensure all states are written in inf-tau form
        for state in self._states:
            if not has_inf_tau_form(state, self._membrane_potential):
                raise HHModelError('State `' + state.qname() +
                                   '` must have "inf-tau form" or'
                                   ' "alpha-beta form". See'
                                   ' `myokit.lib.hh.has_inf_tau_form()` and'
                                   ' `myokit.lib.hh.has_alpha_beta_form()`.')

        #
        # Remove unused variables from internal model, and evaluate any
        # literals.
        #
        # 1. Make sure that current variable is maintained by temporarily
        #    setting it as a state variable.
        # 2. Similarly, make sure parameters and membrane potential are not
        #    evaluated and/or removed
        #
        self._membrane_potential.promote(0)
        if self._current is not None:
            self._current.promote(0)
        for p in self._parameters:
            p.promote(0)

        # Evaluate all constants and remove unused variables
        for var in self._model.variables(deep=True, const=True):
            var.set_rhs(var.rhs().eval())
        self._model.validate(remove_unused_variables=True)

        # Undo the temporary promotions from above
        self._membrane_potential.demote()
        for p in self._parameters:
            p.demote()
        if self._current is not None:
            self._current.demote()

        # Validate modified model
        self._model.validate()

        #
        # Create functions
        #

        # Create a list of inputs to the functions
        self._inputs = [self._membrane_potential] + self._parameters

        # Get the default values for all inputs
        self._default_inputs = np.array([v.eval() for v in self._inputs])

        #
        # Create a function that calculates the states analytically
        #
        # The created self._function has signature _f(_y0, _t, v, params*)
        # and returns a tuple (states, current). If _t is a scalar, states is a
        # sequence of state values and current is a single current value. If _t
        # is a numpy array of times, states is a sequence of arrays, and
        # current is a numpy array. If no current variable is known current is
        # always None.
        #
        f = []
        args = ['_y0', '_t'] + [i.uname() for i in self._inputs]
        f.append('def _f(' + ', '.join(args) + '):')
        f.append('_y = [0]*' + str(len(self._states)))

        # Add equations to calculate all infs and taus (inputs, states, time
        # and the current are handled separately, so they are skipped here)
        w = myokit.numpy_writer()
        order = self._model.solvable_order()
        ignore = set(self._inputs + self._states + [self._model.time()])
        for group in order.values():
            for eq in group:
                var = eq.lhs.var()
                if var in ignore or var == self._current:
                    continue
                f.append(w.eq(eq))

        # Add equations to calculate updated state, using the analytical
        # solution x(t) = inf + (x0 - inf) * exp(-t / tau)
        for k, state in enumerate(self._states):
            inf, tau = get_inf_and_tau(state, self._membrane_potential)
            inf = w.ex(myokit.Name(inf))
            tau = w.ex(myokit.Name(tau))
            k = str(k)
            f.append('_y[' + k + '] = ' + state.uname() + ' = ' + inf +
                     ' + (_y0[' + k + '] - ' + inf + ') * numpy.exp(-_t / ' +
                     tau + ')')

        # Add current calculation
        if self._current is not None:
            f.append('_i = ' + w.ex(self._current.rhs()))
        else:
            f.append('_i = None')

        # Add return statement and create python function from f
        # (the source is generated by myokit's expression writer from the
        # model's equations; `numpy` is the only global made available)
        f.append('return _y, _i')
        for i in range(1, len(f)):
            f[i] = '    ' + f[i]
        f = '\n'.join(f)
        local = {}
        exec(f, {'numpy': np}, local)
        self._function = local['_f']

        #
        #
        # Create a function that calculates the steady states
        #
        #
        g = []
        args = [i.uname() for i in self._inputs]
        g.append('def _g(' + ', '.join(args) + '):')
        g.append('_y = [0]*' + str(len(self._states)))
        for k, state in enumerate(self._states):
            # Only inf is needed for the steady state; tau is unused here
            inf, _ = get_inf_and_tau(state, self._membrane_potential)
            # Expand inf into a single expression of the inputs only
            inf = inf.rhs().clone(expand=True, retain=self._inputs)
            g.append('_y[' + str(k) + '] = ' + w.ex(inf))

        # Create python function from g
        g.append('return _y')
        for i in range(1, len(g)):
            g[i] = '    ' + g[i]
        g = '\n'.join(g)
        local = {}
        exec(g, {'numpy': np}, local)
        self._steady_state_function = local['_g']