def test_no_method_printed_for(s_model):
    """Check the lookup parameters when nothing is printed inside a method context.

    Lookup tables are calculated for all equations of ``s_model`` and every
    equation is printed without wrapping the call in ``method_being_printed``.
    The test asserts that no lookup table row reference shows up in the printed
    output and that the single table reported by ``print_lookup_parameters``
    has an empty ``table_used_in_methods`` set.
    """
    lut = LookupTables(s_model)
    lut.calc_lookup_tables(s_model.equations)
    printer = ChastePrinter(lookup_table_function=lut.print_lut_expr)
    output = "".join(printer.doprint(eq) for eq in s_model.equations)
    assert '_lt_0_row[0]' not in output
    # We can't compare the complete parameters against an external text file, since the data contains sets:
    # printing these would make the test dependent on their order
    params_for_printing = lut.print_lookup_parameters(printer)

    assert len(params_for_printing) == 1
    assert sorted(params_for_printing[0].keys()) == \
        ['lookup_epxrs', 'mTableMaxs', 'mTableMins', 'mTableSteps', 'metadata_tag', 'table_used_in_methods', 'var']
    assert params_for_printing[0]['metadata_tag'] == 'membrane_voltage'
    assert params_for_printing[0]['mTableMins'] == -250.0
    assert params_for_printing[0]['mTableMaxs'] == 550.0
    assert params_for_printing[0]['mTableSteps'] == 0.001
    assert params_for_printing[0]['table_used_in_methods'] == set()
    assert params_for_printing[0]['var'] == 'cell$V'

    expected_path = os.path.join(TESTS_FOLDER, 'test_lookup_tables_no_method_printed_for.txt')
    # Read through a context manager so the reference file is always closed
    with open(expected_path, 'r') as expected_file:
        expected = expected_file.read()
    assert str(params_for_printing[0]['lookup_epxrs']) == expected, str(params_for_printing[0]['lookup_epxrs'])
def test_no_print_after_table(s_model):
    """Printing equations after the lookup parameters were printed must raise a ValueError."""
    lookup_tables = LookupTables(s_model)
    chaste_printer = ChastePrinter(lookup_table_function=lookup_tables.print_lut_expr)

    # calc_lookup_tables is never called in this test; the printed parameter list is expected to be empty
    printed_parameters = lookup_tables.print_lookup_parameters(chaste_printer)
    assert printed_parameters == []

    expected_message = "Cannot print lookup expression after main table has been printed"
    with pytest.raises(ValueError, match=expected_message):
        for equation in s_model.equations:
            chaste_printer.doprint(equation)
class OptCvodeChasteModel(CvodeChasteModel):
    """ Holds information specific for the Cvode Optimised model type. Builds on Cvode model type"""

    def __init__(self, model, file_name, **kwargs):
        """ Set up lookup tables, initialise the parent model and store the lookup parameters for the template.

        :param model: the model, passed on to LookupTables and to the parent initialiser
        :param file_name: passed on unchanged to the parent initialiser
        :param kwargs: passed on to the parent initialiser; the optional 'lookup_table' entry is used as
                       lookup_params for the lookup tables (DEFAULT_LOOKUP_PARAMETERS when absent)
        """
        # Created before calling the parent initialiser
        # NOTE(review): presumably because the parent initialiser triggers the overridden methods below,
        # which use self._lookup_tables -- confirm against the base class
        self._lookup_tables = LookupTables(model,
                                           lookup_params=kwargs.get(
                                               'lookup_table',
                                               DEFAULT_LOOKUP_PARAMETERS))

        super().__init__(model, file_name, **kwargs)
        self._vars_for_template['model_type'] += 'Opt'
        self._vars_for_template[
            'lookup_parameters'] = self._lookup_tables.print_lookup_parameters(
                self._printer)

    def _get_stimulus(self):
        """ Get the partially evaluated stimulus currents in the model"""
        return_stim_eqs = super()._get_stimulus()
        return partial_eval(
            return_stim_eqs,
            self._model.stimulus_params | self._model.modifiable_parameters)

    def _get_extended_ionic_vars(self):
        """ Get the partially evaluated equations defining the ionic derivatives and all dependent equations.

        Lookup tables are calculated for the resulting equations before they are returned.
        """
        extended_ionic_vars = partial_eval(super()._get_extended_ionic_vars(),
                                           self._model.ionic_vars)
        self._lookup_tables.calc_lookup_tables(extended_ionic_vars)
        return extended_ionic_vars

    def _get_derivative_equations(self):
        """ Get partially evaluated equations defining the derivatives including V (self._model.membrane_voltage_var)

        Lookup tables are calculated for the resulting equations before they are returned.
        """
        derivative_equations = partial_eval(
            super()._get_derivative_equations(), self._model.y_derivatives)
        self._lookup_tables.calc_lookup_tables(derivative_equations)
        return derivative_equations

    def _add_printers(self):
        """ Initialises Printers for outputting chaste code, using the lookup tables' print function. """
        super()._add_printers(
            lookup_table_function=self._lookup_tables.print_lut_expr)

    def _print_jacobian(self):
        """ Print the jacobian via the parent, within the 'EvaluateAnalyticJacobian' method context"""
        with self._lookup_tables.method_being_printed(
                'EvaluateAnalyticJacobian'):
            return super()._print_jacobian()

    def _format_ionic_vars(self):
        """ Format equations and dependent equations ionic derivatives, within the 'GetIIonic' method context"""
        with self._lookup_tables.method_being_printed('GetIIonic'):
            return super()._format_ionic_vars()

    def _format_derivative_equations(self, derivative_equations):
        """ Format derivative equations for chaste output, within the 'EvaluateYDerivatives' method context"""
        with self._lookup_tables.method_being_printed('EvaluateYDerivatives'):
            return super()._format_derivative_equations(derivative_equations)

    def _format_derived_quant_eqs(self):
        """ Format equations for derived quantities, within the 'ComputeDerivedQuantities' method context"""
        with self._lookup_tables.method_being_printed(
                'ComputeDerivedQuantities'):
            return super()._format_derived_quant_eqs()
def test_nested_method_printed_for(s_model):
    """Check the lookup parameters when printing happens inside nested method contexts.

    The rhs of every equation of ``s_model`` is printed inside two nested
    ``method_being_printed`` contexts. The test asserts that a lookup table row
    reference appears in the output and that only ``'outer_method'`` is reported
    in ``table_used_in_methods``.
    """
    lut = LookupTables(s_model)
    lut.calc_lookup_tables(s_model.equations)
    printer = ChastePrinter(lookup_table_function=lut.print_lut_expr)

    printed = []
    for eq in s_model.equations:
        # The second context manager is entered inside the first, exactly like nested with-statements
        with lut.method_being_printed('outer_method'), lut.method_being_printed('innter_method'):
            printed.append(printer.doprint(eq.rhs))
    output = "".join(printed)
    assert '_lt_0_row[0]' in output

    params_for_printing = lut.print_lookup_parameters(printer)

    assert len(params_for_printing) == 1
    assert sorted(params_for_printing[0].keys()) == \
        ['lookup_epxrs', 'mTableMaxs', 'mTableMins', 'mTableSteps', 'metadata_tag', 'table_used_in_methods', 'var']
    assert params_for_printing[0]['metadata_tag'] == 'membrane_voltage'
    assert params_for_printing[0]['mTableMins'] == -250.0
    assert params_for_printing[0]['mTableMaxs'] == 550.0
    assert params_for_printing[0]['mTableSteps'] == 0.001
    assert params_for_printing[0]['table_used_in_methods'] == {'outer_method'}
    assert params_for_printing[0]['var'] == 'cell$V'

    expected_path = os.path.join(TESTS_FOLDER, 'test_lookup_tables_nested_method_printed_for.txt')
    # Read through a context manager so the reference file is always closed
    with open(expected_path, 'r') as expected_file:
        expected = expected_file.read()
    assert str(params_for_printing[0]['lookup_epxrs']) == expected, str(params_for_printing[0]['lookup_epxrs'])
def test_no_calc_after_print(s_model):
    """Calculating lookup tables again once equations have been printed must raise a ValueError."""
    lookup_tables = LookupTables(s_model)
    lookup_tables.calc_lookup_tables(s_model.equations)
    chaste_printer = ChastePrinter(lookup_table_function=lookup_tables.print_lut_expr)

    # Print every equation first; the printed text itself is not inspected
    for equation in s_model.equations:
        chaste_printer.doprint(equation)

    expected_message = "Cannot calculate lookup tables after printing has started"
    with pytest.raises(ValueError, match=expected_message):
        lookup_tables.calc_lookup_tables(s_model.equations)
def test_change_lookup_table(be_model):
    """Check lookup tables built from custom lookup parameters.

    Three lookup parameter entries are supplied, one of them with the tag
    ``'unknown_tag'``. The test asserts that only two tables are used in the
    printed output and reported by ``print_lookup_parameters``, and checks the
    reported settings of both.
    """
    lut = LookupTables(be_model, lookup_params=[['membrane_voltage', -25.0001, 54.9999, 0.01],
                                                ['cytosolic_calcium_concentration', 0.0, 50.0, 0.01],
                                                ['unknown_tag', 0.0, 50.0, 0.01]])
    lut.calc_lookup_tables(be_model.equations)
    printer = ChastePrinter(lookup_table_function=lut.print_lut_expr)

    printed = []
    for eq in be_model.equations:
        with lut.method_being_printed('template_method'):
            printed.append(printer.doprint(eq.rhs))
    output = "".join(printed)
    assert '_lt_0_row[0]' in output
    assert '_lt_1_row[0]' in output
    assert '_lt_2_row[0]' not in output

    params_for_printing = lut.print_lookup_parameters(printer)

    assert len(params_for_printing) == 2
    expected_keys = ['lookup_epxrs', 'mTableMaxs', 'mTableMins', 'mTableSteps', 'metadata_tag',
                     'table_used_in_methods', 'var']
    assert all(sorted(p.keys()) == expected_keys for p in params_for_printing)
    assert params_for_printing[0]['metadata_tag'] == 'membrane_voltage'
    assert params_for_printing[0]['mTableMins'] == -25.0001
    assert params_for_printing[0]['mTableMaxs'] == 54.9999
    assert params_for_printing[0]['mTableSteps'] == 0.01
    assert params_for_printing[0]['table_used_in_methods'] == {'template_method'}
    assert params_for_printing[0]['var'] == 'membrane$V'

    expected_path = os.path.join(TESTS_FOLDER, 'test_lookup_tables_change_lookup_table.txt')
    # Read through a context manager so the reference file is always closed
    with open(expected_path, 'r') as expected_file:
        expected = expected_file.read()
    assert str(params_for_printing[0]['lookup_epxrs']) == expected, str(params_for_printing[0]['lookup_epxrs'])

    assert params_for_printing[1]['metadata_tag'] == 'cytosolic_calcium_concentration'
    assert params_for_printing[1]['mTableMins'] == 0.0
    assert params_for_printing[1]['mTableMaxs'] == 50.0
    assert params_for_printing[1]['mTableSteps'] == 0.01
    assert params_for_printing[1]['table_used_in_methods'] == {'template_method'}
    assert params_for_printing[1]['var'] == 'slow_inward_current$Cai'

    assert str(params_for_printing[1]['lookup_epxrs']) \
        == "[['-82.3 - 13.0287 * log(0.001 * slow_inward_current$Cai)', False]]"
class GeneralisedRushLarsenFirstOrderModelOpt(
        GeneralisedRushLarsenFirstOrderModel):
    """ Template and information holder for the GeneralisedRushLarsenOpt model type (uses lookup tables)"""

    def __init__(self, model, file_name, **kwargs):
        """ Create the lookup tables, initialise the parent and store the lookup parameters for the template.

        The optional 'lookup_table' keyword argument is used as lookup_params,
        falling back to DEFAULT_LOOKUP_PARAMETERS when it is not given.
        """
        table_settings = kwargs.get('lookup_table', DEFAULT_LOOKUP_PARAMETERS)
        self._lookup_tables = LookupTables(model, lookup_params=table_settings)

        super().__init__(model, file_name, **kwargs)
        self._vars_for_template['model_type'] += 'Opt'
        printed_parameters = self._lookup_tables.print_lookup_parameters(self._printer)
        self._vars_for_template['lookup_parameters'] = printed_parameters

    def _get_stimulus(self):
        """ Return the parent's stimulus equations after partial evaluation"""
        stimulus_eqs = super()._get_stimulus()
        stimulus_and_modifiable = self._model.stimulus_params | self._model.modifiable_parameters
        return partial_eval(stimulus_eqs, stimulus_and_modifiable)

    def _get_extended_ionic_vars(self):
        """ Return the partially evaluated extended ionic variable equations of the parent.

        Lookup tables are calculated for these equations before returning them.
        """
        parent_eqs = super()._get_extended_ionic_vars()
        evaluated_eqs = partial_eval(parent_eqs, self._model.ionic_vars)
        self._lookup_tables.calc_lookup_tables(evaluated_eqs)
        return evaluated_eqs

    def _get_derivative_equations(self):
        """ Return the partially evaluated derivative equations of the parent.

        Lookup tables are calculated for these equations before returning them.
        """
        parent_eqs = super()._get_derivative_equations()
        evaluated_eqs = partial_eval(parent_eqs, self._model.y_derivatives)
        self._lookup_tables.calc_lookup_tables(evaluated_eqs)
        return evaluated_eqs

    def _add_printers(self):
        """ Set up the printers, handing the lookup tables' print function to the parent"""
        lut_print_function = self._lookup_tables.print_lut_expr
        super()._add_printers(lookup_table_function=lut_print_function)

    def _format_ionic_vars(self):
        """ Format the ionic variable equations inside the 'GetIIonic' method context"""
        tables = self._lookup_tables
        with tables.method_being_printed('GetIIonic'):
            return super()._format_ionic_vars()

    def format_derivative_equation(self, eq, modifiers_with_defining_eqs):
        """ Format a single derivative equation, marking which method it is printed for.

        The derivative of the membrane voltage is formatted without a method context.
        Otherwise an equation in _derivative_eqs_excl_voltage is formatted for
        'ComputeOneStepExceptVoltage'. Independently of both, an equation in
        _derivative_eqs_voltage is (re)formatted for 'UpdateTransmembranePotential'.
        """
        tables = self._lookup_tables
        result = None
        is_dvdt = isinstance(eq.lhs, Derivative) and eq.lhs.args[0] is self._model.membrane_voltage_var

        if is_dvdt:
            result = super().format_derivative_equation(eq, modifiers_with_defining_eqs)
        elif eq in self._derivative_eqs_excl_voltage:
            # Indicate use of lookup table
            with tables.method_being_printed('ComputeOneStepExceptVoltage'):
                result = super().format_derivative_equation(eq, modifiers_with_defining_eqs)

        # Deliberately not part of the chain above: this check always runs
        if eq in self._derivative_eqs_voltage:
            # Indicate use of lookup table
            with tables.method_being_printed('UpdateTransmembranePotential'):
                result = super().format_derivative_equation(eq, modifiers_with_defining_eqs)

        assert result is not None, (
            'Derivative equation should be dvdt or in _derivative_eqs_voltage '
            'or in _derivative_eqs_excl_voltage')
        return result

    def _format_derived_quant_eqs(self):
        """ Format the derived quantity equations inside the 'ComputeDerivedQuantities' method context"""
        tables = self._lookup_tables
        with tables.method_being_printed('ComputeDerivedQuantities'):
            return super()._format_derived_quant_eqs()

    def eq_in_evaluate_y_derivative(self, eq, used_equations):
        """ Let the parent flag the equation, then reprint its rhs if the latest flag is set.

        The reprint happens inside a 'ComputeOneStepExceptVoltage<i>' method context,
        where <i> is the index of the last entry of eq['in_evaluate_y_derivative'].
        """
        super().eq_in_evaluate_y_derivative(eq, used_equations)
        flags = eq['in_evaluate_y_derivative']
        if not flags[-1]:
            return

        # Reprint to indicate use of lookup table
        method_name = 'ComputeOneStepExceptVoltage' + str(len(eq['in_evaluate_y_derivative']) - 1)
        with self._lookup_tables.method_being_printed(method_name):
            derivative_lhs = {deriv_eq.lhs for deriv_eq in self._derivative_equations}
            modifiers_with_defining_eqs = derivative_lhs | self._model.state_vars
            eq['rhs'] = self._print_rhs_with_modifiers(
                eq['sympy_lhs'], eq['sympy_rhs'], modifiers_with_defining_eqs)

    def eq_in_evaluate_partial_derivative(self, eq, used_jacobian_vars):
        """ Let the parent flag the equation, then reprint its rhs if the latest flag is set.

        The reprint happens inside an 'EvaluatePartialDerivative<i>' method context,
        where <i> is the index of the last entry of eq['in_evaluate_partial_derivative'].
        """
        super().eq_in_evaluate_partial_derivative(eq, used_jacobian_vars)
        flags = eq['in_evaluate_partial_derivative']
        if not flags[-1]:
            return

        # Reprint to indicate use of lookup table
        method_name = 'EvaluatePartialDerivative' + str(len(eq['in_evaluate_partial_derivative']) - 1)
        with self._lookup_tables.method_being_printed(method_name):
            eq['rhs'] = self._printer.doprint(eq['sympy_rhs'])
# Example no. 9
# 0
class BackwardEulerOptModel(BackwardEulerModel):
    """ Information holder for the Optimised Backward Euler model type (uses lookup tables)."""

    def __init__(self, model, file_name, **kwargs):
        """ Create the lookup tables, initialise the parent and store the lookup parameters for the template.

        The optional 'lookup_table' keyword argument is used as lookup_params,
        falling back to DEFAULT_LOOKUP_PARAMETERS when it is not given.
        """
        table_settings = kwargs.get('lookup_table', DEFAULT_LOOKUP_PARAMETERS)
        self._lookup_tables = LookupTables(model, lookup_params=table_settings)

        super().__init__(model, file_name, **kwargs)
        self._vars_for_template['model_type'] += 'Opt'
        # Reprint the derivative equations before the lookup parameters are printed
        self._update_formatted_deriv_eq()
        printed_parameters = self._lookup_tables.print_lookup_parameters(self._printer)
        self._vars_for_template['lookup_parameters'] = printed_parameters

    def _get_stimulus(self):
        """ Return the parent's stimulus equations after partial evaluation"""
        stimulus_eqs = super()._get_stimulus()
        stimulus_and_modifiable = self._model.stimulus_params | self._model.modifiable_parameters
        return partial_eval(stimulus_eqs, stimulus_and_modifiable)

    def _get_extended_ionic_vars(self):
        """ Return the partially evaluated extended ionic variable equations of the parent.

        Lookup tables are calculated for these equations before returning them.
        """
        parent_eqs = super()._get_extended_ionic_vars()
        evaluated_eqs = partial_eval(parent_eqs, self._model.ionic_vars)
        self._lookup_tables.calc_lookup_tables(evaluated_eqs)
        return evaluated_eqs

    def _get_derivative_equations(self):
        """ Return the partially evaluated derivative equations of the parent.

        Lookup tables are calculated for these equations before returning them.
        """
        parent_eqs = super()._get_derivative_equations()
        evaluated_eqs = partial_eval(parent_eqs, self._model.y_derivatives)
        self._lookup_tables.calc_lookup_tables(evaluated_eqs)
        return evaluated_eqs

    def _pre_print_hook(self):
        """ Run the parent's pre-print hook, then calculate lookup tables for the jacobian equations."""
        super()._pre_print_hook()
        # calculate lookup tables for the jacobians created by the Backward Euler
        jacobian_eqs = (Eq(jac_lhs, jac_rhs) for jac_lhs, jac_rhs in self._jacobian_equations)
        self._lookup_tables.calc_lookup_tables(jacobian_eqs)

    def _update_formatted_deriv_eq(self):
        """ Reprint the rhs of the y derivative equations inside the method contexts they are flagged for.

        Non-linear equations are reprinted for 'ComputeResidual'; equations flagged
        'in_membrane_voltage' are reprinted for 'UpdateTransmembranePotential'.
        """
        def reprint_rhs(equation, method_name):
            # Print the sympy rhs again while the given method is registered as being printed
            with self._lookup_tables.method_being_printed(method_name):
                equation['rhs'] = self._printer.doprint(equation['sympy_rhs'])

        for deriv_eq in self._vars_for_template['y_derivative_equations']:
            if not deriv_eq['linear']:
                reprint_rhs(deriv_eq, 'ComputeResidual')
            if deriv_eq['in_membrane_voltage']:
                reprint_rhs(deriv_eq, 'UpdateTransmembranePotential')

    def _add_printers(self):
        """ Set up the printers, handing the lookup tables' print function to the parent"""
        lut_print_function = self._lookup_tables.print_lut_expr
        super()._add_printers(lookup_table_function=lut_print_function)

    def _format_ionic_vars(self):
        """ Format the ionic variable equations inside the 'GetIIonic' method context"""
        tables = self._lookup_tables
        with tables.method_being_printed('GetIIonic'):
            return super()._format_ionic_vars()

    def _format_derivative_equations(self, derivative_equations):
        """ Format the derivative equations inside the 'EvaluateYDerivatives' method context"""
        tables = self._lookup_tables
        with tables.method_being_printed('EvaluateYDerivatives'):
            return super()._format_derivative_equations(derivative_equations)

    def _format_derived_quant_eqs(self):
        """ Format the derived quantity equations inside the 'ComputeDerivedQuantities' method context"""
        tables = self._lookup_tables
        with tables.method_being_printed('ComputeDerivedQuantities'):
            return super()._format_derived_quant_eqs()

    def format_linear_deriv_eqs(self, linear_deriv_eqs):
        """ Format the linear derivative equations inside the 'ComputeOneStepExceptVoltage' method context"""
        tables = self._lookup_tables
        with tables.method_being_printed('ComputeOneStepExceptVoltage'):
            return super().format_linear_deriv_eqs(linear_deriv_eqs)

    def format_jacobian(self):
        """ Format the jacobian inside the 'ComputeJacobian' method context"""
        tables = self._lookup_tables
        with tables.method_being_printed('ComputeJacobian'):
            return super().format_jacobian()
class RushLarsenOptModel(RushLarsenModel):
    """ Template and information holder for the optimised RushLarsen model type (uses lookup tables)"""

    def __init__(self, model, file_name, **kwargs):
        """ Create the lookup tables, initialise the parent and store the lookup parameters for the template.

        The optional 'lookup_table' keyword argument is used as lookup_params,
        falling back to DEFAULT_LOOKUP_PARAMETERS when it is not given.
        """
        table_settings = kwargs.get('lookup_table', DEFAULT_LOOKUP_PARAMETERS)
        self._lookup_tables = LookupTables(model, lookup_params=table_settings)

        super().__init__(model, file_name, **kwargs)
        self._vars_for_template['model_type'] += 'Opt'
        printed_parameters = self._lookup_tables.print_lookup_parameters(self._printer)
        self._vars_for_template['lookup_parameters'] = printed_parameters

    def _get_extended_ionic_vars(self):
        """ Return the partially evaluated extended ionic variable equations of the parent.

        Lookup tables are calculated for these equations before returning them.
        """
        parent_eqs = super()._get_extended_ionic_vars()
        evaluated_eqs = partial_eval(parent_eqs, self._model.ionic_vars)
        self._lookup_tables.calc_lookup_tables(evaluated_eqs)
        return evaluated_eqs

    def _get_derivative_equations(self):
        """ Return the partially evaluated derivative equations of the parent.

        Lookup tables are calculated for these equations before returning them.
        """
        parent_eqs = super()._get_derivative_equations()
        evaluated_eqs = partial_eval(parent_eqs, self._model.y_derivatives)
        self._lookup_tables.calc_lookup_tables(evaluated_eqs)
        return evaluated_eqs

    def _add_printers(self):
        """ Set up the printers, handing the lookup tables' print function to the parent"""
        lut_print_function = self._lookup_tables.print_lut_expr
        super()._add_printers(lookup_table_function=lut_print_function)

    def _format_ionic_vars(self):
        """ Format the ionic variable equations inside the 'GetIIonic' method context"""
        tables = self._lookup_tables
        with tables.method_being_printed('GetIIonic'):
            return super()._format_ionic_vars()

    def _format_derivative_equations(self, derivative_equations):
        """ Format the derivative equations inside the 'EvaluateYDerivatives' method context"""
        tables = self._lookup_tables
        with tables.method_being_printed('EvaluateYDerivatives'):
            return super()._format_derivative_equations(derivative_equations)

    def _format_derived_quant_eqs(self):
        """ Format the derived quantity equations inside the 'ComputeDerivedQuantities' method context"""
        tables = self._lookup_tables
        with tables.method_being_printed('ComputeDerivedQuantities'):
            return super()._format_derived_quant_eqs()

    def _get_formatted_alpha_beta(self):
        """ Get the parent's formatted alpha/beta information inside the 'EvaluateEquations' method context.

        This is the information for r_alpha_or_tau, r_beta_or_inf in the c++ output and formatted equations,
        rearranged in the form (inf-x)/tau.
        """
        tables = self._lookup_tables
        with tables.method_being_printed('EvaluateEquations'):
            return super()._get_formatted_alpha_beta()

    def format_deriv_eqs_EvaluateEquations(self, deriv_eqs_EvaluateEquations):
        """ Format the derivative equations belonging to EvaluateEquations.

        Before delegating to the parent, the equations that get_equations_for returns for the
        membrane voltage derivative, but not for any other derivative, are added to
        _derivative_eqs_voltage.
        """
        voltage_derivatives = [
            deriv for deriv in self._model.y_derivatives
            if deriv.args[0] is self._model.membrane_voltage_var
        ]
        eqs_for_voltage = set(get_equations_for(self._model, voltage_derivatives))

        other_derivatives = [
            deriv for deriv in self._model.y_derivatives
            if deriv.args[0] is not self._model.membrane_voltage_var
        ]
        eqs_for_others = set(get_equations_for(self._model, other_derivatives))

        # Keep only the equations used exclusively for the voltage derivative
        eqs_for_voltage.difference_update(eqs_for_others)
        self._derivative_eqs_voltage |= eqs_for_voltage
        return super().format_deriv_eqs_EvaluateEquations(deriv_eqs_EvaluateEquations)