def test_no_method_printed_for(s_model):
    """Lookup expressions are not substituted when no method is being printed.

    The equations are printed without any ``method_being_printed`` context.
    The output must therefore not reference the lookup-table row, and the
    single (membrane voltage) table must report that no methods use it.
    """
    lut = LookupTables(s_model)
    lut.calc_lookup_tables(s_model.equations)
    printer = ChastePrinter(lookup_table_function=lut.print_lut_expr)
    output = "".join(printer.doprint(eq) for eq in s_model.equations)
    assert '_lt_0_row[0]' not in output

    # The parameters are checked field by field rather than against a single
    # external text file: 'table_used_in_methods' is a set, and printing it
    # would make the test depend on set ordering.
    params_for_printing = lut.print_lookup_parameters(printer)
    assert len(params_for_printing) == 1
    params = params_for_printing[0]
    assert sorted(params.keys()) == \
        ['lookup_epxrs', 'mTableMaxs', 'mTableMins', 'mTableSteps', 'metadata_tag', 'table_used_in_methods', 'var']
    assert params['metadata_tag'] == 'membrane_voltage'
    assert params['mTableMins'] == -250.0
    assert params['mTableMaxs'] == 550.0
    assert params['mTableSteps'] == 0.001
    assert params['table_used_in_methods'] == set()
    assert params['var'] == 'cell$V'

    # Use a context manager so the expected-output file is always closed.
    expected_path = os.path.join(TESTS_FOLDER, 'test_lookup_tables_no_method_printed_for.txt')
    with open(expected_path, 'r', encoding='utf-8') as expected_file:
        expected = expected_file.read()
    assert str(params['lookup_epxrs']) == expected, str(params['lookup_epxrs'])
def test_no_print_after_table(s_model):
    """Once the lookup parameters are printed, expressions can no longer be printed.

    ``print_lookup_parameters`` is called before any equation has been
    printed (and without calculating tables), so it yields an empty list.
    Printing the model's equations afterwards is expected to raise a
    ``ValueError``.
    """
    tables = LookupTables(s_model)
    chaste_printer = ChastePrinter(lookup_table_function=tables.print_lut_expr)
    assert tables.print_lookup_parameters(chaste_printer) == []

    expected_error = "Cannot print lookup expression after main table has been printed"
    with pytest.raises(ValueError, match=expected_error):
        for equation in s_model.equations:
            chaste_printer.doprint(equation)
class OptCvodeChasteModel(CvodeChasteModel):
    """Model type for optimised Cvode output.

    Extends the Cvode model type with lookup tables. Equations are partially
    evaluated and lookup tables are calculated from them. Printing of each
    generated method is wrapped in ``method_being_printed`` so the tables
    record which methods use them.
    """

    def __init__(self, model, file_name, **kwargs):
        # The lookup tables are created before the base constructor runs,
        # presumably because it calls the overridden hooks below (e.g.
        # _add_printers), which use self._lookup_tables. TODO confirm.
        lookup_params = kwargs.get('lookup_table', DEFAULT_LOOKUP_PARAMETERS)
        self._lookup_tables = LookupTables(model, lookup_params=lookup_params)
        super().__init__(model, file_name, **kwargs)
        self._vars_for_template['model_type'] += 'Opt'
        lookup_parameters = self._lookup_tables.print_lookup_parameters(self._printer)
        self._vars_for_template['lookup_parameters'] = lookup_parameters

    def _get_stimulus(self):
        """Return the stimulus-current equations, partially evaluated.

        The stimulus parameters and the modifiable parameters are passed to
        ``partial_eval``.
        """
        stimulus_eqs = super()._get_stimulus()
        passed_vars = self._model.stimulus_params | self._model.modifiable_parameters
        return partial_eval(stimulus_eqs, passed_vars)

    def _get_extended_ionic_vars(self):
        """Return the partially evaluated ionic equations and all their dependencies.

        Lookup tables are calculated for the returned equations.
        """
        ionic_eqs = super()._get_extended_ionic_vars()
        ionic_eqs = partial_eval(ionic_eqs, self._model.ionic_vars)
        self._lookup_tables.calc_lookup_tables(ionic_eqs)
        return ionic_eqs

    def _get_derivative_equations(self):
        """Return the partially evaluated derivative equations, including dV/dt.

        Lookup tables are calculated for the returned equations.
        """
        deriv_eqs = super()._get_derivative_equations()
        deriv_eqs = partial_eval(deriv_eqs, self._model.y_derivatives)
        self._lookup_tables.calc_lookup_tables(deriv_eqs)
        return deriv_eqs

    def _add_printers(self):
        """Initialise the printers used for Chaste code output.

        The base implementation receives the lookup-table expression printer.
        """
        lut_expr_printer = self._lookup_tables.print_lut_expr
        super()._add_printers(lookup_table_function=lut_expr_printer)

    def _print_jacobian(self):
        """Print the Jacobian, recording lookup-table use under 'EvaluateAnalyticJacobian'."""
        lookup_context = self._lookup_tables.method_being_printed('EvaluateAnalyticJacobian')
        with lookup_context:
            return super()._print_jacobian()

    def _format_ionic_vars(self):
        """Format the ionic-variable equations and their dependencies under 'GetIIonic'."""
        lookup_context = self._lookup_tables.method_being_printed('GetIIonic')
        with lookup_context:
            return super()._format_ionic_vars()

    def _format_derivative_equations(self, derivative_equations):
        """Format derivative equations for Chaste output under 'EvaluateYDerivatives'."""
        lookup_context = self._lookup_tables.method_being_printed('EvaluateYDerivatives')
        with lookup_context:
            return super()._format_derivative_equations(derivative_equations)

    def _format_derived_quant_eqs(self):
        """Format derived-quantity equations under 'ComputeDerivedQuantities'."""
        lookup_context = self._lookup_tables.method_being_printed('ComputeDerivedQuantities')
        with lookup_context:
            return super()._format_derived_quant_eqs()
def __init__(self, model, file_name, **kwargs):
    """Create the lookup tables, run the base constructor and publish the table parameters.

    Appends 'Opt' to the template's ``model_type`` and stores the printed
    lookup parameters under ``lookup_parameters``.
    """
    # The lookup tables must exist before the base constructor runs,
    # presumably because it calls overridden hooks that use them. TODO confirm.
    table_params = kwargs.get('lookup_table', DEFAULT_LOOKUP_PARAMETERS)
    self._lookup_tables = LookupTables(model, lookup_params=table_params)
    super().__init__(model, file_name, **kwargs)
    self._vars_for_template['model_type'] += 'Opt'
    printed_params = self._lookup_tables.print_lookup_parameters(self._printer)
    self._vars_for_template['lookup_parameters'] = printed_params
def test_nested_method_printed_for(s_model):
    """Nested ``method_being_printed`` contexts record only the outer method.

    Each equation's right-hand side is printed inside two nested method
    contexts. The output must use the lookup table, and the table must report
    only the outermost method name as using it.
    """
    lut = LookupTables(s_model)
    lut.calc_lookup_tables(s_model.equations)
    printer = ChastePrinter(lookup_table_function=lut.print_lut_expr)
    printed = []
    for eq in s_model.equations:
        with lut.method_being_printed('outer_method'):
            with lut.method_being_printed('innter_method'):
                printed.append(printer.doprint(eq.rhs))
    output = "".join(printed)
    assert '_lt_0_row[0]' in output

    params_for_printing = lut.print_lookup_parameters(printer)
    assert len(params_for_printing) == 1
    params = params_for_printing[0]
    assert sorted(params.keys()) == \
        ['lookup_epxrs', 'mTableMaxs', 'mTableMins', 'mTableSteps', 'metadata_tag', 'table_used_in_methods', 'var']
    assert params['metadata_tag'] == 'membrane_voltage'
    assert params['mTableMins'] == -250.0
    assert params['mTableMaxs'] == 550.0
    assert params['mTableSteps'] == 0.001
    assert params['table_used_in_methods'] == {'outer_method'}
    assert params['var'] == 'cell$V'

    # Use a context manager so the expected-output file is always closed.
    expected_path = os.path.join(TESTS_FOLDER, 'test_lookup_tables_nested_method_printed_for.txt')
    with open(expected_path, 'r', encoding='utf-8') as expected_file:
        expected = expected_file.read()
    assert str(params['lookup_epxrs']) == expected, str(params['lookup_epxrs'])
def test_no_calc_after_print(s_model):
    """Calculating lookup tables after printing has started must raise ``ValueError``."""
    tables = LookupTables(s_model)
    tables.calc_lookup_tables(s_model.equations)
    chaste_printer = ChastePrinter(lookup_table_function=tables.print_lut_expr)
    for equation in s_model.equations:
        chaste_printer.doprint(equation)

    expected_error = "Cannot calculate lookup tables after printing has started"
    with pytest.raises(ValueError, match=expected_error):
        tables.calc_lookup_tables(s_model.equations)
def test_change_lookup_table(be_model):
    """Custom lookup parameters produce one table per known metadata tag.

    Three tables are requested. The one with 'unknown_tag' must not appear in
    the printed output, and only two sets of parameters are reported: one for
    membrane voltage and one for cytosolic calcium.
    """
    lut = LookupTables(be_model,
                       lookup_params=[['membrane_voltage', -25.0001, 54.9999, 0.01],
                                      ['cytosolic_calcium_concentration', 0.0, 50.0, 0.01],
                                      ['unknown_tag', 0.0, 50.0, 0.01]])
    lut.calc_lookup_tables(be_model.equations)
    printer = ChastePrinter(lookup_table_function=lut.print_lut_expr)
    printed = []
    for eq in be_model.equations:
        with lut.method_being_printed('template_method'):
            printed.append(printer.doprint(eq.rhs))
    output = "".join(printed)

    assert '_lt_0_row[0]' in output
    assert '_lt_1_row[0]' in output
    assert '_lt_2_row[0]' not in output

    params_for_printing = lut.print_lookup_parameters(printer)
    assert len(params_for_printing) == 2
    expected_keys = ['lookup_epxrs', 'mTableMaxs', 'mTableMins', 'mTableSteps', 'metadata_tag',
                     'table_used_in_methods', 'var']
    assert all(sorted(p.keys()) == expected_keys for p in params_for_printing)

    voltage = params_for_printing[0]
    assert voltage['metadata_tag'] == 'membrane_voltage'
    assert voltage['mTableMins'] == -25.0001
    assert voltage['mTableMaxs'] == 54.9999
    assert voltage['mTableSteps'] == 0.01
    assert voltage['table_used_in_methods'] == {'template_method'}
    assert voltage['var'] == 'membrane$V'
    # Use a context manager so the expected-output file is always closed.
    expected_path = os.path.join(TESTS_FOLDER, 'test_lookup_tables_change_lookup_table.txt')
    with open(expected_path, 'r', encoding='utf-8') as expected_file:
        expected = expected_file.read()
    assert str(voltage['lookup_epxrs']) == expected, str(voltage['lookup_epxrs'])

    calcium = params_for_printing[1]
    assert calcium['metadata_tag'] == 'cytosolic_calcium_concentration'
    assert calcium['mTableMins'] == 0.0
    assert calcium['mTableMaxs'] == 50.0
    assert calcium['mTableSteps'] == 0.01
    assert calcium['table_used_in_methods'] == {'template_method'}
    assert calcium['var'] == 'slow_inward_current$Cai'
    assert str(calcium['lookup_epxrs']) \
        == "[['-82.3 - 13.0287 * log(0.001 * slow_inward_current$Cai)', False]]"
class GeneralisedRushLarsenFirstOrderModelOpt(
        GeneralisedRushLarsenFirstOrderModel):
    """Template information for the GeneralisedRushLarsenOpt model type.

    Adds lookup tables to the first-order generalised Rush-Larsen model type.
    Equations are partially evaluated and lookup tables are calculated from
    them. Printing is wrapped in ``method_being_printed`` so the tables record
    which generated methods use them.
    """

    def __init__(self, model, file_name, **kwargs):
        # The lookup tables are created before the base constructor runs,
        # presumably because it calls the overridden hooks below. TODO confirm.
        lookup_params = kwargs.get('lookup_table', DEFAULT_LOOKUP_PARAMETERS)
        self._lookup_tables = LookupTables(model, lookup_params=lookup_params)
        super().__init__(model, file_name, **kwargs)
        self._vars_for_template['model_type'] += 'Opt'
        lookup_parameters = self._lookup_tables.print_lookup_parameters(self._printer)
        self._vars_for_template['lookup_parameters'] = lookup_parameters

    def _get_stimulus(self):
        """Return the stimulus-current equations, partially evaluated.

        The stimulus parameters and the modifiable parameters are passed to
        ``partial_eval``.
        """
        stimulus_eqs = super()._get_stimulus()
        passed_vars = self._model.stimulus_params | self._model.modifiable_parameters
        return partial_eval(stimulus_eqs, passed_vars)

    def _get_extended_ionic_vars(self):
        """Return the partially evaluated ionic equations and all their dependencies.

        Lookup tables are calculated for the returned equations.
        """
        ionic_eqs = super()._get_extended_ionic_vars()
        ionic_eqs = partial_eval(ionic_eqs, self._model.ionic_vars)
        self._lookup_tables.calc_lookup_tables(ionic_eqs)
        return ionic_eqs

    def _get_derivative_equations(self):
        """Return the partially evaluated derivative equations, including dV/dt.

        Lookup tables are calculated for the returned equations.
        """
        deriv_eqs = super()._get_derivative_equations()
        deriv_eqs = partial_eval(deriv_eqs, self._model.y_derivatives)
        self._lookup_tables.calc_lookup_tables(deriv_eqs)
        return deriv_eqs

    def _add_printers(self):
        """Initialise the printers used for Chaste code output.

        The base implementation receives the lookup-table expression printer.
        """
        lut_expr_printer = self._lookup_tables.print_lut_expr
        super()._add_printers(lookup_table_function=lut_expr_printer)

    def _format_ionic_vars(self):
        """Format the ionic-variable equations and their dependencies under 'GetIIonic'."""
        lookup_context = self._lookup_tables.method_being_printed('GetIIonic')
        with lookup_context:
            return super()._format_ionic_vars()

    def format_derivative_equation(self, eq, modifiers_with_defining_eqs):
        """Format a single derivative equation.

        dV/dt is formatted in whatever method context is currently active.
        Equations in ``_derivative_eqs_excl_voltage`` are formatted under
        'ComputeOneStepExceptVoltage'. Equations in ``_derivative_eqs_voltage``
        are then (re)formatted under 'UpdateTransmembranePotential', so a shared
        equation records lookup-table use in both methods.
        """
        formatted_eq = None
        is_dvdt = (isinstance(eq.lhs, Derivative)
                   and eq.lhs.args[0] is self._model.membrane_voltage_var)
        if is_dvdt:
            formatted_eq = super().format_derivative_equation(eq, modifiers_with_defining_eqs)
        elif eq in self._derivative_eqs_excl_voltage:
            # Record lookup-table use for the non-voltage step
            with self._lookup_tables.method_being_printed('ComputeOneStepExceptVoltage'):
                formatted_eq = super().format_derivative_equation(eq, modifiers_with_defining_eqs)

        # Deliberately a separate check (not elif): voltage-related equations
        # are reprinted to record lookup-table use in the voltage update.
        if eq in self._derivative_eqs_voltage:
            with self._lookup_tables.method_being_printed('UpdateTransmembranePotential'):
                formatted_eq = super().format_derivative_equation(eq, modifiers_with_defining_eqs)

        assert formatted_eq is not None, (
            'Derivative equation should be dvdt or in _derivative_eqs_voltage '
            'or in _derivative_eqs_excl_voltage')
        return formatted_eq

    def _format_derived_quant_eqs(self):
        """Format derived-quantity equations under 'ComputeDerivedQuantities'."""
        lookup_context = self._lookup_tables.method_being_printed('ComputeDerivedQuantities')
        with lookup_context:
            return super()._format_derived_quant_eqs()

    def eq_in_evaluate_y_derivative(self, eq, used_equations):
        """Flag whether eq's lhs appears in used_equations, reprinting it if it does.

        When the latest flag is set, ``eq['rhs']`` is reprinted under a
        numbered 'ComputeOneStepExceptVoltage<n>' method name so the lookup
        tables record that use.
        """
        super().eq_in_evaluate_y_derivative(eq, used_equations)
        usage_flags = eq['in_evaluate_y_derivative']
        if not usage_flags[-1]:
            return
        method_name = f'ComputeOneStepExceptVoltage{len(usage_flags) - 1}'
        with self._lookup_tables.method_being_printed(method_name):
            modifiers_with_defining_eqs = \
                {deriv_eq.lhs for deriv_eq in self._derivative_equations} | self._model.state_vars
            eq['rhs'] = self._print_rhs_with_modifiers(
                eq['sympy_lhs'], eq['sympy_rhs'], modifiers_with_defining_eqs)

    def eq_in_evaluate_partial_derivative(self, eq, used_jacobian_vars):
        """Flag whether eq's lhs appears in used_jacobian_vars, reprinting it if it does.

        When the latest flag is set, ``eq['rhs']`` is reprinted under a
        numbered 'EvaluatePartialDerivative<n>' method name.
        """
        super().eq_in_evaluate_partial_derivative(eq, used_jacobian_vars)
        usage_flags = eq['in_evaluate_partial_derivative']
        if not usage_flags[-1]:
            return
        method_name = f'EvaluatePartialDerivative{len(usage_flags) - 1}'
        with self._lookup_tables.method_being_printed(method_name):
            eq['rhs'] = self._printer.doprint(eq['sympy_rhs'])
class BackwardEulerOptModel(BackwardEulerModel):
    """Model type for optimised Backward Euler output.

    Adds lookup tables to the Backward Euler model type. Equations (including
    the Jacobian) are partially evaluated and lookup tables are calculated
    from them. Printing is wrapped in ``method_being_printed`` so the tables
    record which generated methods use them.
    """

    def __init__(self, model, file_name, **kwargs):
        # The lookup tables are created before the base constructor runs,
        # presumably because it calls the overridden hooks below. TODO confirm.
        lookup_params = kwargs.get('lookup_table', DEFAULT_LOOKUP_PARAMETERS)
        self._lookup_tables = LookupTables(model, lookup_params=lookup_params)
        super().__init__(model, file_name, **kwargs)
        self._vars_for_template['model_type'] += 'Opt'
        # Reprint derivative equations so lookup-table use is recorded per method
        self._update_formatted_deriv_eq()
        lookup_parameters = self._lookup_tables.print_lookup_parameters(self._printer)
        self._vars_for_template['lookup_parameters'] = lookup_parameters

    def _get_stimulus(self):
        """Return the stimulus-current equations, partially evaluated.

        The stimulus parameters and the modifiable parameters are passed to
        ``partial_eval``.
        """
        stimulus_eqs = super()._get_stimulus()
        passed_vars = self._model.stimulus_params | self._model.modifiable_parameters
        return partial_eval(stimulus_eqs, passed_vars)

    def _get_extended_ionic_vars(self):
        """Return the partially evaluated ionic equations and all their dependencies.

        Lookup tables are calculated for the returned equations.
        """
        ionic_eqs = super()._get_extended_ionic_vars()
        ionic_eqs = partial_eval(ionic_eqs, self._model.ionic_vars)
        self._lookup_tables.calc_lookup_tables(ionic_eqs)
        return ionic_eqs

    def _get_derivative_equations(self):
        """Return the partially evaluated derivative equations, including dV/dt.

        Lookup tables are calculated for the returned equations.
        """
        deriv_eqs = super()._get_derivative_equations()
        deriv_eqs = partial_eval(deriv_eqs, self._model.y_derivatives)
        self._lookup_tables.calc_lookup_tables(deriv_eqs)
        return deriv_eqs

    def _pre_print_hook(self):
        """Run the base hook, then calculate lookup tables for the Jacobian.

        The base hook retrieves the linear and non-linear derivatives and the
        relevant Jacobian. Each (lhs, rhs) pair in ``_jacobian_equations`` is
        passed on as an ``Eq``.
        """
        super()._pre_print_hook()
        # Lookup tables for the Jacobian created by the Backward Euler method
        jacobian_eqs = (Eq(lhs, rhs) for lhs, rhs in self._jacobian_equations)
        self._lookup_tables.calc_lookup_tables(jacobian_eqs)

    def _update_formatted_deriv_eq(self):
        """Reprint derivative equations to record lookup-table use.

        Non-linear equations are reprinted under 'ComputeResidual'. Equations
        flagged ``in_membrane_voltage`` are reprinted under
        'UpdateTransmembranePotential'.
        """
        for deriv in self._vars_for_template['y_derivative_equations']:
            if not deriv['linear']:
                with self._lookup_tables.method_being_printed('ComputeResidual'):
                    deriv['rhs'] = self._printer.doprint(deriv['sympy_rhs'])
            if deriv['in_membrane_voltage']:
                with self._lookup_tables.method_being_printed('UpdateTransmembranePotential'):
                    deriv['rhs'] = self._printer.doprint(deriv['sympy_rhs'])

    def _add_printers(self):
        """Initialise the printers used for Chaste code output.

        The base implementation receives the lookup-table expression printer.
        """
        lut_expr_printer = self._lookup_tables.print_lut_expr
        super()._add_printers(lookup_table_function=lut_expr_printer)

    def _format_ionic_vars(self):
        """Format the ionic-variable equations and their dependencies under 'GetIIonic'."""
        lookup_context = self._lookup_tables.method_being_printed('GetIIonic')
        with lookup_context:
            return super()._format_ionic_vars()

    def _format_derivative_equations(self, derivative_equations):
        """Format derivative equations for Chaste output under 'EvaluateYDerivatives'."""
        lookup_context = self._lookup_tables.method_being_printed('EvaluateYDerivatives')
        with lookup_context:
            return super()._format_derivative_equations(derivative_equations)

    def _format_derived_quant_eqs(self):
        """Format derived-quantity equations under 'ComputeDerivedQuantities'."""
        lookup_context = self._lookup_tables.method_being_printed('ComputeDerivedQuantities')
        with lookup_context:
            return super()._format_derived_quant_eqs()

    def format_linear_deriv_eqs(self, linear_deriv_eqs):
        """Format linear derivative equations under 'ComputeOneStepExceptVoltage'.

        This records which method uses the lookup tables.
        """
        lookup_context = self._lookup_tables.method_being_printed('ComputeOneStepExceptVoltage')
        with lookup_context:
            return super().format_linear_deriv_eqs(linear_deriv_eqs)

    def format_jacobian(self):
        """Format the Jacobian under 'ComputeJacobian' to record lookup-table use."""
        lookup_context = self._lookup_tables.method_being_printed('ComputeJacobian')
        with lookup_context:
            return super().format_jacobian()
class RushLarsenOptModel(RushLarsenModel):
    """Template information for the optimised RushLarsen model type.

    Adds lookup tables to the RushLarsen model type. Equations are partially
    evaluated and lookup tables are calculated from them. Printing is wrapped
    in ``method_being_printed`` so the tables record which generated methods
    use them.
    """

    def __init__(self, model, file_name, **kwargs):
        # The lookup tables are created before the base constructor runs,
        # presumably because it calls the overridden hooks below. TODO confirm.
        lookup_params = kwargs.get('lookup_table', DEFAULT_LOOKUP_PARAMETERS)
        self._lookup_tables = LookupTables(model, lookup_params=lookup_params)
        super().__init__(model, file_name, **kwargs)
        self._vars_for_template['model_type'] += 'Opt'
        lookup_parameters = self._lookup_tables.print_lookup_parameters(self._printer)
        self._vars_for_template['lookup_parameters'] = lookup_parameters

    def _get_extended_ionic_vars(self):
        """Return the partially evaluated ionic equations and all their dependencies.

        Lookup tables are calculated for the returned equations.
        """
        ionic_eqs = super()._get_extended_ionic_vars()
        ionic_eqs = partial_eval(ionic_eqs, self._model.ionic_vars)
        self._lookup_tables.calc_lookup_tables(ionic_eqs)
        return ionic_eqs

    def _get_derivative_equations(self):
        """Return the partially evaluated derivative equations, including dV/dt.

        Lookup tables are calculated for the returned equations.
        """
        deriv_eqs = super()._get_derivative_equations()
        deriv_eqs = partial_eval(deriv_eqs, self._model.y_derivatives)
        self._lookup_tables.calc_lookup_tables(deriv_eqs)
        return deriv_eqs

    def _add_printers(self):
        """Initialise the printers used for Chaste code output.

        The base implementation receives the lookup-table expression printer.
        """
        lut_expr_printer = self._lookup_tables.print_lut_expr
        super()._add_printers(lookup_table_function=lut_expr_printer)

    def _format_ionic_vars(self):
        """Format the ionic-variable equations and their dependencies under 'GetIIonic'."""
        lookup_context = self._lookup_tables.method_being_printed('GetIIonic')
        with lookup_context:
            return super()._format_ionic_vars()

    def _format_derivative_equations(self, derivative_equations):
        """Format derivative equations for Chaste output under 'EvaluateYDerivatives'."""
        lookup_context = self._lookup_tables.method_being_printed('EvaluateYDerivatives')
        with lookup_context:
            return super()._format_derivative_equations(derivative_equations)

    def _format_derived_quant_eqs(self):
        """Format derived-quantity equations under 'ComputeDerivedQuantities'."""
        lookup_context = self._lookup_tables.method_being_printed('ComputeDerivedQuantities')
        with lookup_context:
            return super()._format_derived_quant_eqs()

    def _get_formatted_alpha_beta(self):
        """Return the r_alpha_or_tau / r_beta_or_inf information for the C++ output.

        The formatted equations are rearranged into the form (inf - x) / tau.
        Printing happens under 'EvaluateEquations' so lookup-table use is
        recorded for that method.
        """
        lookup_context = self._lookup_tables.method_being_printed('EvaluateEquations')
        with lookup_context:
            return super()._get_formatted_alpha_beta()

    def format_deriv_eqs_EvaluateEquations(self, deriv_eqs_EvaluateEquations):
        """Format the derivative equations belonging to EvaluateEquations.

        Equations needed only by dV/dt, and not by any other state
        derivative, are added to ``_derivative_eqs_voltage`` before the base
        implementation runs.
        """
        voltage_var = self._model.membrane_voltage_var
        voltage_derivs = [d for d in self._model.y_derivatives if d.args[0] is voltage_var]
        other_derivs = [d for d in self._model.y_derivatives if d.args[0] is not voltage_var]
        voltage_only_eqs = set(get_equations_for(self._model, voltage_derivs))
        other_eqs = set(get_equations_for(self._model, other_derivs))
        voltage_only_eqs -= other_eqs
        self._derivative_eqs_voltage |= voltage_only_eqs
        return super().format_deriv_eqs_EvaluateEquations(deriv_eqs_EvaluateEquations)