def planar_wave(ode, domain, D, S1, threshold, d=5):
    '''
    Finds the conduction velocity of a planar wave.

    Loads a 0D steady state for the ode model at a BCL of 1000 from
    ../data/steadystates/, or computes it with find_steadystate if the file
    cannot be read. It then initiates a single planar wave with the stimulus
    field S1 and measures the conduction velocity half-way in the domain.
    The velocity is derived from the activation times of two nodes 2*d
    apart.

    Parameters
    ----------
    ode : str
        Name of the cell model. Passed to load_ode and used in the
        steady-state file name.
    domain : dolfin mesh
        Computational domain.
    D :
        Conductivity passed to CardiacModel.
    S1 :
        Stimulus expression. Interpolated onto a vector CG1 space and
        applied as the ODE field parameters.
    threshold : float
        Value whose upward crossing marks a node as activated.
    d : int
        Offset of the two monitored nodes from index N // 2.

    Returns
    -------
    float or None
        The measured conduction velocity, or None if the wave did not
        activate both monitored nodes before tstop.

    NOTE(review): relies on module-level field_states, field_parameters,
    GOSSparams, find_steadystate, ODE, tstop, dt, N and h.
    '''
    # Load the cell model from file
    cellmodel = load_ode(ode)
    cellmodel = dolfin_jit(cellmodel, field_states, field_parameters)

    # Create the CardiacModel for the given domain and cell model
    heart = CardiacModel(domain, Constant(0.0), D, None, cellmodel)

    # Create the solver
    solver = GOSSplittingSolver(heart, GOSSparams())
    dolfin_solver = solver.ode_solver
    ode_solver = dolfin_solver._ode_system_solvers[0]

    BCL = 1000
    sspath = "../data/steadystates/"
    try:
        steadystate = np.load(sspath + "steadystate_%s_BCL%d.npy" % (ode, BCL))
    except (IOError, OSError):
        # Only an unreadable/missing cache file falls back to 0D pacing;
        # a bare except would also hide unrelated errors and interrupts.
        print("Did not find steadystate for %s at BCL: %g, pacing 0D model"
              % (ode, BCL))
        # NOTE(review): ODE is not defined in this function (presumably a
        # module-level global); confirm it is not meant to be `ode`.
        steadystate = find_steadystate(ODE, BCL, 0.01)

    # Load steadystate into 2D model
    (u, um) = solver.solution_fields()
    for node in range(domain.coordinates().shape[0]):
        ode_solver.states(node)[:] = steadystate
        u.vector()[node] = steadystate[0]

    # Define the planar wave stimulus
    V = VectorFunctionSpace(domain, 'CG', 1)
    S1 = interpolate(S1, V).vector().array()

    # Apply the stimulus
    ode_solver.set_field_parameters(S1)

    # Indices of the two monitored nodes; integer division keeps them
    # valid indices under Python 3.
    ix = N // 2 - d
    iy = N // 2 + d

    xp = 0
    yp = 0
    t0 = None  # activation time of the X node, None until it activates
    cv = None
    for timestep, (u, vm) in solver.solve((0, tstop), dt):
        x = u.vector()[ix]
        y = u.vector()[iy]
        print(u.vector().max(), x, y)
        if x > threshold and (x - threshold) * (xp - threshold) < 0:
            # X cell has activated
            t0 = timestep[0]
        if (y > threshold and (y - threshold) * (yp - threshold) < 0
                and t0 is not None):
            # Y cell has activated
            t1 = timestep[0]
            # Calculate conduction velocity
            cv = 2 * d * h / (t1 - t0) * 1e3
            print("CV: ", cv)
            break
        xp = x
        yp = y
    return cv
def module():
    """Build and return a module generated from the reference ODE.

    The tentusscher_2004_mcell_updated ODE is loaded and turned into Python
    source by gotran's PythonCodeGenerator using the default code generation
    parameters. That source is then executed inside a fresh, loader-less
    module object named "ode".
    """
    ode_file = _here.joinpath("tentusscher_2004_mcell_updated.ode")
    reference_ode = gotran.load_ode(ode_file)

    # Python code generator with default code generation parameters
    generator = gotran.PythonCodeGenerator(default_params)
    source = generator.module_code(reference_ode)

    # Execute the generated source inside an empty module object
    generated = importlib.util.module_from_spec(
        importlib.util.spec_from_loader("ode", loader=None))
    exec(source, generated.__dict__)
    return generated
def __init__(self, **kwargs):
    """Configure the test case from keyword overrides of the defaults.

    If ``field_states`` or ``field_parameters`` is given as a dict, it is
    narrowed to the entry for the selected ``ode_model``. Every resulting
    parameter becomes an instance attribute. After that the ODE is loaded
    and the optional field-state and field-parameter getter callbacks are
    invoked.
    """
    COSSTestCase.check_kwargs(**kwargs)
    params = COSSTestCase.default_parameters()
    params.update(kwargs)

    # Per-model dicts are reduced to the entry for this ode_model
    for key in ("field_states", "field_parameters"):
        if key in params and isinstance(params[key], dict):
            params[key] = params[key][params["ode_model"]]

    self.__dict__.update(params)
    print(self.ode_model)
    self.ode = load_ode(self.ode_model)

    if self.field_states_getter_fn is None:
        self.stored_field_states = None
        self.field_states_fn = None
    else:
        # The getter receives the list it should store field states in
        self.stored_field_states = []
        self.field_states_fn = self.field_states_getter_fn(
            self.num_nodes, self.stored_field_states)

    # Initial values of the ODE parameters that are field parameters
    initial_field_params = [
        p.init for p in self.ode.parameters if p.name in self.field_parameters
    ]
    self.field_parameter_values = (
        None
        if self.field_parameter_values_getter_fn is None
        else self.field_parameter_values_getter_fn(
            initial_field_params, self.num_nodes, self.double))
def test_extract_components():
    """Extracted sub-ODEs keep the parameter values of the full model.

    Potassium, sodium and calcium components are extracted (in that order)
    from tentusscher_2004_mcell_updated. Every object in each extracted
    sub-ODE must have the same parameter value as the object of the same
    name in the full ODE.
    """
    ode = gotran.load_ode(_here.joinpath("tentusscher_2004_mcell_updated.ode"))

    # (name of the extracted sub-ODE, components to extract)
    groups = [
        ("Potassium", (
            "Rapid time dependent potassium current",
            "Inward rectifier potassium current",
            "Slow time dependent potassium current",
            "Potassium pump current",
            "Potassium dynamics",
            "Transient outward current",
        )),
        ("Sodium", (
            "Fast sodium current",
            "Sodium background current",
            "Sodium potassium pump current",
            "Sodium calcium exchanger current",
            "Sodium dynamics",
        )),
        ("Calcium", (
            "Calcium dynamics",
            "Calcium background current",
            "Calcium pump current",
            "L type ca current",
        )),
    ]

    for group_name, components in groups:
        extracted = ode.extract_components(group_name, *components)
        for name, obj in list(extracted.present_ode_objects.items()):
            full_obj = ode.present_ode_objects[name]
            assert full_obj.param.value == pytest.approx(obj.param.value)
def __init__(
    self,
    ode,
    num_nodes,
    dt,
    tstop,
    t0=0.0,
    solver="rush_larsen",
    field_states=None,
    field_states_getter_fn=None,
    field_parameters=None,
    field_parameter_values_getter_fn=None,
    block_size=1024,
    double=True,
    statesrepr="named",
    paramrepr="named",
    bodyrepr="named",
    use_cse=False,
    update_host_states=False,
    update_field_states=True,
):
    """Store the simulation and code-generation options.

    Parameters
    ----------
    ode :
        Passed to load_ode; the loaded ODE is stored on self.ode.
    num_nodes, dt, tstop, t0, solver :
        Number of nodes, time step, stop time, start time and solver name.
    field_states, field_parameters : list of str, optional
        Names of the field states / field parameters. Default to a fresh
        [""] for every instance.
    field_states_getter_fn : callable, optional
        Called as field_states_getter_fn(num_nodes); the result is stored
        as self.field_states_fn (None when not given).
    field_parameter_values_getter_fn : callable, optional
        Called as field_parameter_values_getter_fn(num_nodes, double); the
        result is stored as self.field_parameter_values (None when not
        given).
    block_size, double, statesrepr, paramrepr, bodyrepr, use_cse,
    update_host_states, update_field_states :
        Stored unchanged as attributes of the same name.
    """
    # Build a new default list per call: a shared mutable default would
    # leak changes made through one instance into all later instances.
    if field_states is None:
        field_states = [""]
    if field_parameters is None:
        field_parameters = [""]

    self.ode = load_ode(ode)
    self.num_nodes = num_nodes
    self.dt = dt
    self.tstop = tstop
    self.t0 = t0
    self.solver = solver
    self.field_states = field_states
    self.field_states_fn = (field_states_getter_fn(num_nodes)
                            if field_states_getter_fn is not None else None)
    self.field_parameters = field_parameters
    self.field_parameter_values = (
        field_parameter_values_getter_fn(num_nodes, double)
        if field_parameter_values_getter_fn is not None else None)
    self.block_size = block_size
    self.double = double
    self.statesrepr = statesrepr
    self.paramrepr = paramrepr
    self.bodyrepr = bodyrepr
    self.use_cse = use_cse
    self.update_host_states = update_host_states
    self.update_field_states = update_field_states
def generate_code(gen_params=None):
    """Generate Python function code for a fixed set of code components.

    The tentusscher_2004_mcell_updated ODE is loaded. The following
    components are created from it, using ``gen_params.code`` (or None
    when ``gen_params`` is None) as code parameters:
    rhs; monitored expressions for i_NaK, i_NaCa, i_CaL and d_fCa;
    componentwise derivative (argument 15); linearized derivatives;
    jacobian; diagonal jacobian; jacobian action.

    Returns the list of generated function source strings in that order.
    """
    ode = gotran.load_ode(_here.joinpath("tentusscher_2004_mcell_updated.ode"))
    if gen_params is None:
        code_params = None
    else:
        code_params = gen_params.code
    codegen = gotran.PythonCodeGenerator(gen_params)

    # The jacobian is reused by the diagonal and action components
    jac = gotran.jacobian_expressions(ode, params=code_params)
    monitored = ["i_NaK", "i_NaCa", "i_CaL", "d_fCa"]
    comps = [
        gotran.rhs_expressions(ode, params=code_params),
        gotran.monitored_expressions(ode, monitored, params=code_params),
        gotran.componentwise_derivative(ode, 15, params=code_params),
        gotran.linearized_derivatives(ode, params=code_params),
        jac,
        gotran.diagonal_jacobian_expressions(jac, params=code_params),
        gotran.jacobian_action_expressions(jac, params=code_params),
    ]
    return list(map(codegen.function_code, comps))
def __init__(self, ode):
    """Set up a GOSS splitting solver for a cell model on a 2D square.

    Parameters
    ----------
    ode :
        Cell model passed to load_ode; the argument itself is stored as
        self.ode.

    NOTE(review): relies on module-level L, N, D, field_states and
    field_parameters.
    """
    # Square domain from (0, 0) to (L, L) with N x N subdivisions
    domain = RectangleMesh(0, 0, L, L, N, N)

    # Load the cell model from file and compile it
    cellmodel = dolfin_jit(load_ode(ode), field_states, field_parameters)

    # Create the CardiacModel for the given domain and cell model
    heart = CardiacModel(domain, Constant(0.0), D, None, cellmodel)
    solver = GOSSplittingSolver(heart, self.GOSSparams())

    # Keep handles to the solver and its first ODE system sub-solver
    self.solver = solver
    self.ode = ode
    self.dolfin_solver = solver.ode_solver
    self.ode_solver = self.dolfin_solver._ode_system_solvers[0]
    self.domain = domain
# Mark the SA node: BZ is 1 on the nodes of elements with M > 10 and -1 on
# all other vertices.
BZ = -numpy.ones(mesh.num_vertices())
edx = numpy.nonzero(M > 10)[0]  # The elements in SAN
idx = numpy.unique(E[edx, 0:3])  # The nodes in these elements
idx = idx.astype(int)
BZ[idx] = 1

pv = Viper(meshlist, BZ)
pv.interactive()

idx0 = np.nonzero(BZ >= 0)[0]  # SA cells
idx1 = np.nonzero(BZ < 0)[0]  # normal cells
m0 = len(idx0)
m1 = len(idx1)

# One cell model per cell type
ode0 = jit(load_ode("difrancesco"))
ode1 = jit(load_ode("myocyte.ode"))

solver0 = GRL2()  # alternative: ImplicitEuler()
solver1 = GRL2()  # alternative: ImplicitEuler()

system_solver0 = ODESystemSolver(m0, solver0, ode0)
system_solver1 = ODESystemSolver(m1, solver1, ode1)

# BZ is one-dimensional (one value per vertex), so it is indexed by idx0
# alone; BZ[idx0, :] raised "too many indices for array".
P = make_parameter_field(m0, ode0, {'distance': BZ[idx0]})
system_solver0.set_field_parameters(P)

V0 = np.zeros(m0)
system_solver0.get_field_states(V0)
def GOSSparams():
    """Return GOSSplittingSolver parameters used for the atrial simulation.

    Monodomain PDE with an iterative linear solver and theta = 1.0, RL1
    ODE scheme, and the stimulus current not applied to the PDE.
    """
    params = GOSSplittingSolver.default_parameters()
    params["pde_solver"] = "monodomain"
    monodomain = params["MonodomainSolver"]
    monodomain["linear_solver_type"] = "iterative"
    monodomain["theta"] = 1.0
    params["ode_solver"]["scheme"] = "RL1"
    params["apply_stimulus_current_to_pde"] = False
    return params


# Plot settings (window_width=1366, window_height=744 may be added)
axis = [-80., 40.]
save_fig = True
plot_args = dict(range_min=0., range_max=1., mode='color')

# Cell model name and atrial mesh
ode = 'FK_cAF'
domain = Mesh('atrial_mesh.xml')

# Load the cell model from file and compile it.
# NOTE(review): field_states, field_parameters and D must be defined
# earlier in the script.
cellmodel = dolfin_jit(load_ode(ode), field_states, field_parameters)

# Create the CardiacModel for the given domain and cell model
heart = CardiacModel(domain, Constant(0.0), D, None, cellmodel)
def pacing_cv(ode, BCL_range, D, dt, threshold=0, stim_amp=0, plot_args=None):
    """
    Measure the conduction velocity in a 1D strand for a range of BCLs.

    For each BCL:

    * the strand (length 20, spacing 0.5) is initialised with a 0D quasi
      steady state, loaded from ../data/steadystates/ or computed with
      find_steadystate;
    * the stimulus field parameters are set to BCL and stim_amp;
    * three pulses are simulated. The first 50 time units of each pulse are
      solved with time step dt, the rest of the cycle with 0.1.

    The velocity is computed from the threshold crossing times of nodes 20
    and 35, linearly interpolated within a time step.

    Parameters
    ----------
    ode : str
        Cell model name, used for loading and in file names.
    BCL_range : sequence
        Basic cycle lengths to test.
    D : float
        Diffusion coefficient (wrapped in a dolfin Constant).
    dt : float
        Time step used during the first 50 time units of every pulse.
    threshold : float
        Activation threshold.
    stim_amp : float
        Stimulus amplitude written to the field parameters.
    plot_args : dict, optional
        If truthy, the solution is plotted every step with these arguments.

    Returns
    -------
    ndarray, shape (4, len(BCL_range))
        Row 0 holds the BCLs, rows 1-3 the velocity of each pulse (0 where
        none was measured). Also saved to ../data/results/cv_<ode>.npy.
    """
    D = Constant(D)

    # Must be adjusted with temporal and spatial resolution
    L = 20
    h = 0.5
    dt_low = 0.1

    # Create dolfin mesh
    N = int(L / h)
    domain = IntervalMesh(N, 0, L)

    # Keep the model name for file names; `ode` is rebound to the loaded
    # model below and would otherwise be formatted into the paths.
    ode_name = ode

    # Load the cell model from file
    ode = load_ode(ode)
    cellmodel = dolfin_jit(ode, field_states=["V"],
                           field_parameters=["stim_period", "stim_amplitude"])

    # Create the CardiacModel for the given mesh and cell model
    heart = CardiacModel(domain, Constant(0.0), D, None, cellmodel)

    # Create the solver
    solver = GOSSplittingSolver(heart, GOSSparams())

    # Get the solution fields and subsolvers
    dolfin_solver = solver.ode_solver
    ode_solver = dolfin_solver._ode_system_solvers[0]

    # Nodes whose activation times give the conduction velocity
    ix, iy = 20, 35
    distance = (iy - ix) * h

    sspath = "../data/steadystates/"
    results = np.zeros((4, len(BCL_range)))
    for i, BCL in enumerate(BCL_range):
        # Set the stimulus parameter field; only the first two entries
        # (presumably node 0's stim_period and stim_amplitude) are non-zero
        stim_field = np.zeros(2 * (N + 1), dtype=np.float64)
        stim_field[0] = BCL
        stim_field[1] = stim_amp
        ode_solver.set_field_parameters(stim_field)

        # Pace 0D cell model and find quasi steady state
        try:
            states = np.load(sspath + "steadystate_%s_BCL%d.npy"
                             % (ode_name, BCL))
        except (IOError, OSError):
            print("Did not find steadystate for %s at BCL: %g, pacing 0D model"
                  % (ode_name, BCL))
            # Keep the computed state; discarding it left `states` unbound
            # (or stale from the previous BCL).
            # NOTE(review): ODE is a module-level name here; confirm.
            states = find_steadystate(ODE, BCL, 0.01)

        # Load quasi steady state into 1D model
        (u, um) = solver.solution_fields()
        for node in range(N + 1):
            ode_solver.states(node)[:] = states
            u.vector()[node] = states[0]

        # Previous-step values used for measuring cell activation
        xp = 0
        yp = 0
        tp = 0.0
        t0 = None  # interpolated activation time of node ix

        results[0][i] = BCL
        print("BCL: %g" % BCL)

        # Do 3 pulses
        for pulsenr in range(1, 4):
            # Solve the pulse in higher temporal resolution.
            # NOTE(review): every pulse re-solves the same time window
            # (i*BCL, (i+1)*BCL); this relies on the stimulus being
            # periodic with period BCL.
            for timestep, (u, vm) in solver.solve((i * BCL, i * BCL + 50), dt):
                if plot_args:
                    plot(u, **plot_args)
                x = u.vector()[ix]
                y = u.vector()[iy]
                if x > threshold and (x - threshold) * (xp - threshold) < 0:
                    t0 = (threshold - xp) / (x - xp) * dt + tp
                if y > threshold and (y - threshold) * (yp - threshold) < 0:
                    t1 = (threshold - yp) / (y - yp) * dt + tp
                    # Only measure when the upstream node activated first
                    if t0 is not None:
                        # Calculate conduction velocity
                        Cv = distance / (t1 - t0) * 1e3
                        print("\tCv: %g" % Cv)
                        results[pulsenr][i] = Cv
                    t0 = None
                xp = x
                yp = y
                tp = timestep[0]

            # Wait for next pulse
            for timestep, (u, vm) in solver.solve(
                    (i * BCL + 50, (i + 1) * BCL), dt_low):
                if plot_args:
                    plot(u, **plot_args)

    np.save("../data/results/cv_%s" % ode_name, results)
    return results
# Vertex coordinates reordered through the dofmap's vertex_to_dof_map
vertex_to_dof_map = V.dofmap().vertex_to_dof_map(mesh)
fenics_ordered_coordinates = mesh.coordinates()[vertex_to_dof_map]
N_thread = fenics_ordered_coordinates.shape[0]

# BZ holds the result of find_leaf_nodes() (presumably distances to the
# leaf nodes); entries >= 0 are treated as SA cells.
BZ = find_leaf_nodes()

# Cell type per node: 0 for SA cells, 1 for normal cells
idx = BZ >= 0
celltype = np.ones(N_thread)
celltype[idx] = 0
idx0 = np.flatnonzero(celltype == 0)  # SA cells
idx1 = np.flatnonzero(celltype == 1)  # normal cells
m0 = len(idx0)
m1 = len(idx1)
N = m0 + m1

# Separate cell models and ImplicitEuler solvers for the two cell types
ode0 = jit(load_ode("difrancesco.ode"))
ode1 = jit(load_ode("myocyte.ode"))
solvermethod0 = ImplicitEuler()
solvermethod1 = ImplicitEuler()
ode_solver0 = ODESystemSolver(m0, solvermethod0, ode0)
ode_solver1 = ODESystemSolver(m1, solvermethod1, ode1)

goss_wrap0 = Goss_wrapper(ode_solver0, advance0, V)
goss_wrap1 = Goss_wrapper(ode_solver1, advance1, V)
def _apply_parameter_gradient(model, values, y_shift, param_scale, V, L):
    """Replace each parameter named in `values` on `model` by a field.

    The field is ``param_scale`` interpolated on V with offset set to the
    model's current value and scale chosen so that the Gaussian centre
    (at y = y_shift[name] * L / 4) takes the value values[name].
    """
    for param in values:
        p0 = model.get_parameter(param)
        param_scale.offset = p0
        param_scale.scale = -(p0 - values[param])
        param_scale.center_y = y_shift[param] * L / 4
        p_func = Function(V, name=param)
        p_func.interpolate(param_scale)
        model.set_parameter(param, p_func)


def setup_model(cellmodel_strs, domain, space):
    """
    Run a DOLFINODESystemSolver simulation of one or two cell models.

    Parameters of the models vary spatially, and the first field state is
    plotted every step until t = 300.

    Parameters
    ----------
    cellmodel_strs : list of str
        "rice_model_2008" and/or "hybrid". With two models, the first is
        used where x[1] > 0.5 (label 10), the second where x[1] <= 0.5
        (label 20).
    domain : dolfin mesh
    space : str
        Function space string such as "P_1"; passed to
        family_and_degree_from_str and DOLFINODESystemSolver.

    Raises
    ------
    ValueError
        If not exactly one or two cell models are given.
    """
    L = domain.coordinates().max()
    family, degree = family_and_degree_from_str(space)
    V = FunctionSpace(domain, family, degree)
    Vs = FunctionSpace(domain, "P", 1)

    field_parameters = dict(rice_model_2008=["hf", "hb"],
                            hybrid=["TMon_coop", "TMon_pow"])
    field_states = dict(rice_model_2008=["active"],
                        hybrid=["active"])  # , "TCa"

    # Alter spatially varying parameters: Gaussian bump around
    # (center_x, center_y) added to offset
    param_scale = Expression(
        "offset+scale*exp(-((x[0]-center_x)*(x[0]-center_x)+"
        "(x[1]-center_y)*(x[1]-center_y))/(sigma*sigma))",
        center_x=3 * L / 4, center_y=L / 4, offset=0.0, sigma=L / 2,
        scale=1.0)

    cellmodels = [
        dolfin_jit(load_ode(cellmodel),
                   field_states=field_states[cellmodel],
                   field_parameters=field_parameters[cellmodel])
        for cellmodel in cellmodel_strs]

    if "rice_model_2008" in cellmodel_strs:
        max_value = 0.14
        model = cellmodels[cellmodel_strs.index("rice_model_2008")]
        _apply_parameter_gradient(model, dict(hf=0.03, hb=0.06),
                                  dict(hf=2.5, hb=3), param_scale, V, L)

    if "hybrid" in cellmodel_strs:
        max_value = 0.14
        model = cellmodels[cellmodel_strs.index("hybrid")]
        _apply_parameter_gradient(model, dict(TMon_coop=0.5, TMon_pow=0.5),
                                  dict(TMon_coop=1.5, TMon_pow=2.0),
                                  param_scale, V, L)

    if len(cellmodels) == 2:
        subdomain = CompiledSubDomain("x[1] <= 0.5")
        if space == "P_1":
            cellmodel_domains = VertexFunction("size_t", domain, 10)
        else:
            cellmodel_domains = CellFunction("size_t", domain, 10)
        subdomain.mark(cellmodel_domains, 20)
        cellmodels = dict(zip([10, 20], cellmodels))
    elif len(cellmodels) == 1:
        cellmodels = cellmodels[0]
        cellmodel_domains = None
    else:
        # An explicit error survives `python -O`, unlike `assert False`
        raise ValueError("Expected one or two cell models, got %d"
                         % len(cellmodels))

    params = DOLFINODESystemSolver.default_parameters()
    params["solver"] = "RL1"
    solver = DOLFINODESystemSolver(domain, cellmodels,
                                   domains=cellmodel_domains,
                                   space=space, params=params)
    u = Function(solver.state_space)
    solver.from_field_states(u)

    # Compare against the same "P_1" spelling used above; the original
    # compared against "P1", which never matched, so every space was
    # projected.
    is_p1 = space == "P_1"

    def projected_first_field():
        # First field state (or the only one) projected onto Vs
        first = u[0] if solver.num_field_states > 1 else u
        return project(first, Vs)

    if not is_p1:
        u_plot = Function(Vs)
        u_plot.assign(projected_first_field())
    elif solver.num_field_states > 1:
        u_plot = u.split(True)[0]
    else:
        u_plot = u

    t = 0
    dt = .1
    tstop = 300
    while t < tstop:
        solver.step((t, t + dt), u)
        if (t % 10.) < dt:
            print("t:", t)
        if not is_p1:
            u_plot.assign(projected_first_field())
        elif solver.num_field_states > 1:
            assign(u_plot, u.sub(0))
        plot(u_plot, scale=0., range_max=max_value, range_min=0.)
        t += dt

    plot(u_plot, scale=0., range_max=max_value, range_min=0., interactive=True)
# import os
from gotran import load_ode
from goss import *
import numpy as np

# Integrate the "oscilator" ODE with RL1. The reference solution
# u_exact = (cos t, sin t) corresponds to u(0) = (1, 0).
oscilator = jit(load_ode("oscilator"))
# solver1 = ExplicitEuler(oscilator)
solver1 = RL1(oscilator)

tstop = 10.
n_steps = 1000
time = np.linspace(0, tstop, n_steps + 1)
dt = time[1] - time[0]

u0 = np.array([1.0, 0.0])
u_exact = np.array([np.cos(time), np.sin(time)])
u1 = np.zeros_like(u_exact)
u2 = np.zeros_like(u_exact)
u1[:, 0] = u0
u2[:, 0] = u0
print(u1)

# forward() updates u in place over one step of length dt starting at t,
# so t is the start time of the step (time[step - 1]), not its end time.
# Copy u0 so the initial condition is not overwritten by the in-place
# updates.
u = u0.copy()
for step in range(1, n_steps + 1):
    t = time[step - 1]
    solver1.forward(u, t, dt)
    u1[:, step] = u
class TestODESystemSolver(object):
    """Tests of the jit-compiled ODE interface and of ODESystemSolver.

    tentusscher_2004_mcell is compiled once for the whole class, with field
    states V and Ca_i and field parameters g_CaL and K_o.
    """

    ode = jit(load_ode("tentusscher_2004_mcell"),
              field_states=["V", "Ca_i"],
              field_parameters=["g_CaL", "K_o"])

    def test_ode_interface(self):
        """Sizes, parameter get/set and name queries of the compiled ODE."""
        assert self.ode.num_field_states() == 2
        assert self.ode.num_field_parameters() == 2
        assert self.ode.num_states() == 17
        assert self.ode.num_parameters() == 45

        Cm = self.ode.get_parameter("Cm")
        assert Cm == 0.185
        # The compiled ODE is shared by the class, so restore Cm even if
        # the assertion fails.
        try:
            self.ode.set_parameter("Cm", 0.2)
            assert self.ode.get_parameter("Cm") == 0.2
        finally:
            self.ode.set_parameter("Cm", 0.185)

        assert isinstance(self.ode.get_state_names(), list)
        assert isinstance(self.ode.get_parameter_names(), list)
        assert len(self.ode.get_state_names()) == 17
        assert len(self.ode.get_parameter_names()) == 45
        assert "stim_start" in self.ode.get_parameter_names()

        state_names = [
            "Xr1", "Xr2", "Xs", "m", "h", "j", "d", "f", "fCa", "s", "r", "g",
            "Ca_i", "Ca_SR", "Na_i", "V", "K_i"
        ]
        assert sorted(self.ode.get_state_names()) == sorted(state_names)

        with pytest.raises(RuntimeError):
            self.ode.set_parameter("cm", 0.185)

        with pytest.raises(RuntimeError):
            self.ode.get_parameter("JADA")

    def test_ode_system_interface(self):
        """State access, field-state layouts and input validation."""
        num_nodes = 100
        solver = RL1()
        system = ODESystemSolver(num_nodes, solver, self.ode)
        system.reset_default()
        num_fs = self.ode.num_field_states()
        num_fp = self.ode.num_field_parameters()
        field_states = np.zeros(num_fs * num_nodes)

        assert isinstance(system.states(), np.ndarray)
        assert len(system.states()) == num_nodes * self.ode.num_states()

        # Check tangled access
        system.get_field_states(field_states, True)
        assert sum(field_states[::2] == -86.2) == num_nodes
        assert sum(field_states[1::2] == 0.0002) == num_nodes

        # Check untangled access
        system.get_field_states(field_states, False)
        assert sum(field_states[:num_nodes] == -86.2) == num_nodes
        assert sum(field_states[num_nodes:] == 0.0002) == num_nodes

        assert system.num_nodes() == num_nodes

        with pytest.raises(ValueError):
            system.set_field_states(np.zeros(num_fs * num_nodes // 2))

        with pytest.raises(TypeError):
            system.set_field_states(np.zeros(num_fs * num_nodes, dtype=int))

        # Field-parameter arrays are sized by the number of field
        # parameters (the original used the field-state count, which only
        # coincided because both are 2 for this model).
        with pytest.raises(ValueError):
            system.set_field_parameters(np.zeros(num_fp * num_nodes // 2))

        with pytest.raises(TypeError):
            system.set_field_parameters(
                np.zeros(num_fp * num_nodes, dtype=int))
# Stimulus parameters that vary in space
field_parameters = ["stim_amplitude", "stim_offset", "stim_period"]

# Plot settings (window_width=1366, window_height=744 may be added)
axis = [-80., 40.]
save_fig = True
plot_args = {'range_min': 0., 'range_max': 1., 'mode': 'color'}

# Atrial geometry with the FK_cAF cell model
ode = 'FK_cAF'
domain = Mesh('atrial_mesh.xml')

# Load the cell model from file and compile it.
# NOTE(review): field_states and D must be defined earlier in the script.
cellmodel = load_ode(ode)
cellmodel = dolfin_jit(cellmodel, field_states, field_parameters)


def GOSSparams():
    """Return splitting-solver parameters.

    Monodomain with an iterative linear solver and theta = 1.0, RL1 ODE
    scheme, and no stimulus current in the PDE.
    """
    params = GOSSplittingSolver.default_parameters()
    params["pde_solver"] = "monodomain"
    for key, value in (("linear_solver_type", "iterative"), ("theta", 1.0)):
        params["MonodomainSolver"][key] = value
    params["ode_solver"]["scheme"] = "RL1"
    params["apply_stimulus_current_to_pde"] = False
    return params


# Create the CardiacModel for the given domain and cell model
heart = CardiacModel(domain, Constant(0.0), D, None, cellmodel)

# Create the solver
def pacing_cv(ode, BCL_range, D, dt, threshold=0, stim_amp=0, plot_args=None):
    """
    Measure the conduction velocity in a 1D strand for a range of BCLs.

    For each BCL:

    * the strand (length 20, spacing 0.5) is initialised with a 0D quasi
      steady state, loaded from ../data/steadystates/ or computed with
      find_steadystate;
    * the stimulus field parameters are set to BCL and stim_amp;
    * three pulses are simulated. The first 50 time units of each pulse are
      solved with time step dt, the rest of the cycle with 0.1.

    The velocity is computed from the threshold crossing times of nodes 20
    and 35, linearly interpolated within a time step.

    Parameters
    ----------
    ode : str
        Cell model name, used for loading and in file names.
    BCL_range : sequence
        Basic cycle lengths to test.
    D : float
        Diffusion coefficient (wrapped in a dolfin Constant).
    dt : float
        Time step used during the first 50 time units of every pulse.
    threshold : float
        Activation threshold.
    stim_amp : float
        Stimulus amplitude written to the field parameters.
    plot_args : dict, optional
        If truthy, the solution is plotted every step with these arguments.

    Returns
    -------
    ndarray, shape (4, len(BCL_range))
        Row 0 holds the BCLs, rows 1-3 the velocity of each pulse (0 where
        none was measured). Also saved to ../data/results/cv_<ode>.npy.
    """
    D = Constant(D)

    # Must be adjusted with temporal and spatial resolution
    L = 20
    h = 0.5
    dt_low = 0.1

    # Create dolfin mesh
    N = int(L / h)
    domain = IntervalMesh(N, 0, L)

    # Keep the model name for file names; `ode` is rebound to the loaded
    # model below and would otherwise be formatted into the paths.
    ode_name = ode

    # Load the cell model from file
    ode = load_ode(ode)
    cellmodel = dolfin_jit(ode, field_states=["V"],
                           field_parameters=["stim_period", "stim_amplitude"])

    # Create the CardiacModel for the given mesh and cell model
    heart = CardiacModel(domain, Constant(0.0), D, None, cellmodel)

    # Create the solver
    solver = GOSSplittingSolver(heart, GOSSparams())

    # Get the solution fields and subsolvers
    dolfin_solver = solver.ode_solver
    ode_solver = dolfin_solver._ode_system_solvers[0]

    # Nodes whose activation times give the conduction velocity
    ix, iy = 20, 35
    distance = (iy - ix) * h

    sspath = "../data/steadystates/"
    results = np.zeros((4, len(BCL_range)))
    for i, BCL in enumerate(BCL_range):
        # Set the stimulus parameter field; only the first two entries
        # (presumably node 0's stim_period and stim_amplitude) are non-zero
        stim_field = np.zeros(2 * (N + 1), dtype=np.float64)
        stim_field[0] = BCL
        stim_field[1] = stim_amp
        ode_solver.set_field_parameters(stim_field)

        # Pace 0D cell model and find quasi steady state
        try:
            states = np.load(sspath + "steadystate_%s_BCL%d.npy"
                             % (ode_name, BCL))
        except (IOError, OSError):
            print("Did not find steadystate for %s at BCL: %g, pacing 0D model"
                  % (ode_name, BCL))
            # Keep the computed state; discarding it left `states` unbound
            # (or stale from the previous BCL).
            # NOTE(review): ODE is a module-level name here; confirm.
            states = find_steadystate(ODE, BCL, 0.01)

        # Load quasi steady state into 1D model
        (u, um) = solver.solution_fields()
        for node in range(N + 1):
            ode_solver.states(node)[:] = states
            u.vector()[node] = states[0]

        # Previous-step values used for measuring cell activation
        xp = 0
        yp = 0
        tp = 0.0
        t0 = None  # interpolated activation time of node ix

        results[0][i] = BCL
        print("BCL: %g" % BCL)

        # Do 3 pulses
        for pulsenr in range(1, 4):
            # Solve the pulse in higher temporal resolution.
            # NOTE(review): every pulse re-solves the same time window
            # (i*BCL, (i+1)*BCL); this relies on the stimulus being
            # periodic with period BCL.
            for timestep, (u, vm) in solver.solve((i * BCL, i * BCL + 50), dt):
                if plot_args:
                    plot(u, **plot_args)
                x = u.vector()[ix]
                y = u.vector()[iy]
                if x > threshold and (x - threshold) * (xp - threshold) < 0:
                    t0 = (threshold - xp) / (x - xp) * dt + tp
                if y > threshold and (y - threshold) * (yp - threshold) < 0:
                    t1 = (threshold - yp) / (y - yp) * dt + tp
                    # Only measure when the upstream node activated first
                    if t0 is not None:
                        # Calculate conduction velocity
                        Cv = distance / (t1 - t0) * 1e3
                        print("\tCv: %g" % Cv)
                        results[pulsenr][i] = Cv
                    t0 = None
                xp = x
                yp = y
                tp = timestep[0]

            # Wait for next pulse
            for timestep, (u, vm) in solver.solve(
                    (i * BCL + 50, (i + 1) * BCL), dt_low):
                if plot_args:
                    plot(u, **plot_args)

    np.save("../data/results/cv_%s" % ode_name, results)
    return results
# Generate CUDA kernels from a gotran ODE via the C code generator.
# NOTE(review): `drv` is used but not imported in the visible code
# (presumably pycuda.driver); this statement sequence is also cut off at
# the end of the visible code.
from gotran import CCodeGenerator
from gotran import load_ode
from gotran import ODERepresentation
drv.init()
# NOTE(review): the device handle is created twice; the second line looks
# redundant.
dev = drv.Device(0)
dev = drv.Device(0)
# Architecture string such as "sm_XY" from the device compute capability
arch = "sm_%d%d" % dev.compute_capability()
optimisations = dict()
filename = "tentusscher_panfilov_2006_M_cell_continuous"
ode = load_ode(filename)
# Get num states and parameters which sets the offset into the state
# and parameter array
num_states = ode.num_states
num_params = ode.num_parameters
oderepr = ODERepresentation(ode, **optimisations)
ccode = CCodeGenerator(oderepr)
# Turn the generated C init function into a __global__ kernel in which
# each thread writes the states at offset thread_ind * num_states
init_state_code = (ccode.init_states_code().replace(
    "void", "__global__ void").replace(
    "{", "{\n const int thread_ind = blockIdx.x*blockDim.x + threadIdx.x;"
    "\n const int offset = thread_ind*%d;" % num_states,
def pacing_cv(ode, BCL_range, D, L, h, dt):
    """
    Simulate a paced 1D strand with operator splitting.

    A quasi steady state for BCL = 1000 is loaded from
    ../data/steadystate_<ode>_BCL1000.npy, or computed with
    find_steadystate, and copied into every node of a strand of length L
    with spacing h. The stim_period field parameter is set to
    [BCL, 0, ..., 0]. The cell ODEs and an implicit (M + dt*K) diffusion
    solve are then alternated with time step dt until t = 1e6, plotting
    every step.

    NOTE(review): the original description claimed 50 paced cycles, five
    beats and a conduction velocity measurement; none of that is
    implemented here, and BCL_range is unused.
    """
    # Create domain
    N = int(L / h)  # number of nodes - 1
    domain = IntervalMesh(N, 0, L)

    # Define functions
    V = FunctionSpace(domain, 'Lagrange', 1)
    u = TrialFunction(V)
    up = Function(V)
    v = TestFunction(V)

    # Assemble the mass matrix and the stiffness matrix; A = M + dt*K is
    # the implicit diffusion operator
    M = assemble(u * v * dx)
    K = assemble(Constant(dt) * inner(D * grad(u), grad(v)) * dx)
    A = M + K
    pde_solver = LUSolver(A)
    pde_solver.parameters["reuse_factorization"] = True

    # Set up ODESolver; keep the model name for the steady-state file name,
    # since `ode` is rebound to the loaded model
    ode_name = ode
    ode = load_ode(ode)
    compiled_ode = dolfin_jit(ode, field_states=["V"],
                              field_parameters=["stim_period"])
    params = DOLFINODESystemSolver.default_parameters()
    # NOTE(review): other DOLFINODESystemSolver setups set params["solver"];
    # confirm that "scheme" is a valid key here.
    params["scheme"] = "RL1"

    BCL = 1000
    ode_solver = DOLFINODESystemSolver(domain, compiled_ode, params=params)
    solver = ode_solver._ode_system_solvers[0]

    try:
        states = np.load("../data/steadystate_%s_BCL%d.npy" % (ode_name, BCL))
    except (IOError, OSError):
        # Missing/unreadable cache file: compute the steady state instead
        states = find_steadystate(BCL, 50, dt, ode, plot_results=False)

    for i in range(N + 1):
        solver.states(i)[:] = states
        up.vector()[i] = solver.states(i)[0]

    # stim_period field: BCL for the first entry (presumably node 0),
    # 0 for the remaining N entries. np.float_ was removed in NumPy 2.0.
    solver.set_field_parameters(np.array([BCL] + [0] * N, dtype=np.float64))

    t = 0.
    tstop = 1e6
    while t < tstop:
        # Step ODE solver
        ode_solver.step((t, t + dt), up)

        # Assemble RHS and solve pde
        b = M * up.vector()
        pde_solver.solve(up.vector(), b)

        # Plot solution
        plot(up, range_min=-80.0, range_max=40.0)
        t += dt
from gotran import load_ode
from goss import *

# Scratch script: compile the myocyte model and check isinstance on a
# small helper class.
ode = jit(load_ode('myocyte.ode'))


def f(x):
    """Return x unchanged."""
    return x


class Tull:
    """Minimal class with a single attribute ``a`` set to 2."""

    def __init__(self):
        self.a = 2


a = f
b = Tull()
# Function-call form of print so the script also runs on Python 3
print(isinstance(b, Tull))
# Anisotropic diagonal conductivity (Dx, Dy defined earlier in the script)
D = as_tensor([[Dx, 0.], [0, Dy]])
dt = 0.125
tstop = 25.0
a = Constant(1.)
V_init = -85.
V_amp = 85.
t = Constant(0.)

# Domain and solution space
do_plot = False
L = 100.
N = 1024
# N = 128
domain = RectangleMesh(-L, -L, L, L, N, N)

cellmodel = dolfin_jit(load_ode("tentusscher_panfilov_2006_M_cell.ode"),
                       field_states=["V"])
heart = CardiacModel(domain, t, D, None, cellmodel)

ps = GOSSplittingSolver.default_parameters()
ps["pde_solver"] = "monodomain"
ps["MonodomainSolver"]["linear_solver_type"] = "iterative"
ps["MonodomainSolver"]["theta"] = 1.0
ps["ode_solver"]["solver"] = "RL1"
# Share 8 threads among the MPI processes. Floor division keeps the thread
# count an integer under Python 3 (plain / yields a float there).
ps["ode_solver"]["num_threads"] = 8 // MPI.size(domain.mpi_comm())

# If cuda
ps["ode_solver"]["use_cuda"] = True
ps["ode_solver"]["cuda_params"]["float_precision"] = "double"
ps["ode_solver"]["cuda_params"]["solver"] = "rush_larsen"
def test_subode():
    """Assemble the TenTusscher 2004 ODE from sub-ODEs and compare it to the
    ODE loaded from file.

    This checks that:

    * importing the "Sodium" sub-ODE replaces the parameter Na_i with a
      State of the same name in the dependency graph;
    * a prefixed import ("Potassium" with prefix "pot") and component
      extraction from the file ODE both work;
    * every object in the finalized ODE has exactly the same type and
      parameter value as its counterpart in the file ODE.
    """
    ode_from_file = gotran.load_ode(
        _here.joinpath("tentusscher_2004_mcell_updated"))
    ode = gotran.ODE("Tentusscher_2004_merged")

    # Add parameters and states
    ode.add_parameters(
        Na_i=ScalarParam(11.6),
        Na_o=ScalarParam(140),
        K_i=ScalarParam(138.3),
        K_o=ScalarParam(5.4),
        Ca_o=ScalarParam(2),
        Ca_i=ScalarParam(0.0002),
    )
    mem = ode("Membrane")
    mem.add_states(V=ScalarParam(-86.2))
    mem.add_parameters(
        Cm=ScalarParam(0.185),
        F=ScalarParam(96485.3415),
        R=ScalarParam(8314.472),
        T=ScalarParam(310),
        V_c=ScalarParam(0.016404),
        stim_amplitude=ScalarParam(0),
        stim_duration=ScalarParam(1),
        stim_period=ScalarParam(1000),
        stim_start=ScalarParam(1),
    )
    rev_pot = ode("Reversal potentials")
    rev_pot.add_parameter("P_kna", ScalarParam(0.03))

    # Add intermediates
    ode.i_Stim = (
        -mem.stim_amplitude
        * (1 - 1 / (1 + exp(5.0 * ode.t - 5.0 * mem.stim_start)))
        / (1 + exp(5.0 * ode.t - 5.0 * mem.stim_start
                   - 5.0 * mem.stim_duration)))
    rev_pot.E_Na = mem.R * mem.T * log(ode.Na_o / ode.Na_i) / mem.F
    rev_pot.E_K = mem.R * mem.T * log(ode.K_o / ode.K_i) / mem.F
    rev_pot.E_Ks = (mem.R * mem.T * log(
        (ode.Na_o * rev_pot.P_kna + ode.K_o)
        / (ode.K_i + ode.Na_i * rev_pot.P_kna),
    ) / mem.F)
    rev_pot.E_Ca = 0.5 * mem.R * mem.T * log(ode.Ca_o / ode.Ca_i) / mem.F

    # Get the E_Na expression and one dependency
    E_Na = ode.present_ode_objects["E_Na"]
    Na_i = ode.present_ode_objects["Na_i"]
    assert isinstance(Na_i, gotran.Parameter)

    # Check dependencies
    assert Na_i in ode.expression_dependencies[E_Na]
    assert E_Na in ode.object_used_in[Na_i]

    ode.import_ode(_here.joinpath("Sodium"))

    # Check dependencies after sub ode has been loaded
    E_Na = ode.present_ode_objects["E_Na"]
    Na_i_new = ode.present_ode_objects["Na_i"]
    assert isinstance(Na_i_new, gotran.State)
    assert Na_i not in ode.expression_dependencies[E_Na]
    assert Na_i_new in ode.expression_dependencies[E_Na]
    assert E_Na not in ode.object_used_in[Na_i]
    assert E_Na in ode.object_used_in[Na_i_new]

    pot = gotran.load_ode(_here.joinpath("Potassium"))
    ode.import_ode(pot, prefix="pot")
    pot_comps = [comp.name for comp in pot.components]

    # Add sub ode by extracting components
    ode.import_ode(
        ode_from_file,
        components=[
            "Calcium dynamics",
            "Calcium background current",
            "Calcium pump current",
            "L type ca current",
        ],
    )

    # Get all the Potassium currents
    i_K1 = ode("Inward rectifier potassium current").pot_i_K1
    i_Kr = ode("Rapid time dependent potassium current").pot_i_Kr
    i_Ks = ode("Slow time dependent potassium current").pot_i_Ks
    i_to = ode("Transient outward current").pot_i_to
    i_p_K = ode("Potassium pump current").pot_i_p_K

    # Get all the Sodium currents
    i_Na = ode("Fast sodium current").i_Na
    i_b_Na = ode("Sodium background current").i_b_Na
    i_NaCa = ode("Sodium calcium exchanger current").i_NaCa
    i_NaK = ode("Sodium potassium pump current").i_NaK

    # Get all the Calcium currents
    i_CaL = ode("L type ca current").i_CaL
    i_b_Ca = ode("Calcium background current").i_b_Ca
    i_p_Ca = ode("Calcium pump current").i_p_Ca

    # Membrane potential derivative
    mem.dV_dt = (-i_Ks - i_p_K - i_Na - i_K1 - i_p_Ca - i_b_Ca - i_NaK
                 - i_CaL - i_Kr - ode.i_Stim - i_NaCa - i_b_Na - i_to)

    # Finalize ODE
    ode.finalize()

    for name, obj in list(ode.present_ode_objects.items()):
        # Objects of the prefixed potassium components carry "pot_" in
        # their names; strip it to find the counterpart in the file ODE
        if ode.object_component[obj].name in pot_comps:
            loaded_obj = ode_from_file.present_ode_objects[name.replace(
                "pot_", "")]
        else:
            loaded_obj = ode_from_file.present_ode_objects[name]

        # An exact type match is required, so compare identities (E721)
        assert type(obj) is type(loaded_obj)
        assert loaded_obj.param.value == pytest.approx(obj.param.value)
def spiral(ode, domain, D, S1, S2, plot_args=None, save_fig=False,
           threshold=0, BCL=500):
    """
    Initiate a spiral wave with an S1-S2 protocol and record re-activations.

    The steady state for the given BCL is loaded from ../data/steadystates/
    (or computed with find_steadystate). Then:

    * S1 is applied until S2_time;
    * S2 is applied for 100 time units;
    * all stimuli are switched off until tstop.

    Every step is plotted, and every 5th frame is written to
    ../tmp/randfib2_spiral_<ode><index> when save_fig is true.

    Parameters
    ----------
    ode : str
        Cell model name, used for loading and in file names.
    domain : dolfin mesh
    D :
        Conductivity passed to CardiacModel.
    S1, S2 :
        Stimulus expressions interpolated onto a vector CG1 space and used
        as ODE field parameters.
    plot_args : dict, optional
        Extra keyword arguments for plot (default: none).
    save_fig : bool
        Whether to write frames to disk.
    threshold : float
        Threshold for detecting upward crossings at the monitored entry.
    BCL : int
        Basic cycle length of the steady state to load.

    Returns
    -------
    list of float
        Start times of the steps in which the last entry of the solution
        vector crossed `threshold` upwards after the stimuli were switched
        off.

    NOTE(review): relies on module-level field_states, field_parameters,
    GOSSparams, find_steadystate, ODE, S2_time, tstop and dt.
    """
    # None instead of a shared mutable {} default
    if plot_args is None:
        plot_args = {}

    def save_frame(fig, cnt):
        # Write every 5th frame when figure saving is enabled
        if save_fig and cnt % 5 == 0:
            padded_index = '%08d' % cnt
            fig.write_png('../tmp/randfib2_spiral_%s' % ode + padded_index)

    # Load the cell model from file
    cellmodel = load_ode(ode)
    cellmodel = dolfin_jit(cellmodel, field_states, field_parameters)

    # Create the CardiacModel for the given domain and cell model
    heart = CardiacModel(domain, Constant(0.0), D, None, cellmodel)

    # Create the solver
    solver = GOSSplittingSolver(heart, GOSSparams())
    dolfin_solver = solver.ode_solver
    ode_solver = dolfin_solver._ode_system_solvers[0]

    # Calculate parameter fields (np.float_ was removed in NumPy 2.0)
    V = VectorFunctionSpace(domain, 'CG', 1)
    S1 = interpolate(S1, V).vector().array()
    S2 = interpolate(S2, V).vector().array()
    nostim = np.zeros(S1.shape, dtype=np.float64)

    # Read in steady state from file
    sspath = "../data/steadystates/"
    try:
        steadystate = np.load(sspath + "steadystate_%s_BCL%d.npy" % (ode, BCL))
    except (IOError, OSError):
        print("Did not find steadystate for %s at BCL: %g, pacing 0D model"
              % (ode, BCL))
        # NOTE(review): ODE is a module-level name here; confirm.
        steadystate = find_steadystate(ODE, BCL, 0.01)

    # Load steadystate into 2D model
    (u, um) = solver.solution_fields()
    for node in range(domain.coordinates().shape[0]):
        ode_solver.states(node)[:] = steadystate
        u.vector()[node] = steadystate[0]

    cnt = 0
    # Stimulate S1 pulse
    plot(u, interactive=True, **plot_args)
    ode_solver.set_field_parameters(S1)
    for timestep, (u, vm) in solver.solve((0, S2_time), dt):
        fig = plot(u, interactive=False, **plot_args)
        save_frame(fig, cnt)
        cnt += 1

    # Stimulate S2 pulse
    ode_solver.set_field_parameters(S2)
    for timestep, (u, vm) in solver.solve((S2_time, S2_time + 100), dt):
        fig = plot(u, interactive=False, **plot_args)
        save_frame(fig, cnt)
        cnt += 1

    # Turn off all stimulus and iterate until simulation stop.
    # The last vector entry is monitored to find the spiral period
    # (described originally as the top right corner).
    ode_solver.set_field_parameters(nostim)
    trig_time = []
    xp = 0
    for timestep, (u, vm) in solver.solve((S2_time + 100, tstop), dt):
        fig = plot(u, interactive=False, **plot_args)
        x = u.vector()[-1]
        if x > threshold and (x - threshold) * (xp - threshold) < 0:
            trig_time.append(timestep[0])
            print("Trigger: %g" % timestep[0])
        xp = x
        save_frame(fig, cnt)
        cnt += 1
    return trig_time
class TestODESolvers(object):
    """Convergence-order and long-run tests for the goss ODE solvers."""

    # Expected convergence order per solver name
    orders = dict(RK4=4, RK2=2, RL1=1, RL2=2, GRL1=1, GRL2=2, ExplicitEuler=1,
                  ThetaSolver=2, BasicImplicitEuler=1, RKF32=3, ESDIRK23a=3,
                  ESDIRK4O32=4, ImplicitEuler=1)
    oscilator = jit(load_ode("oscilator"))
    # Solvers skipped in the oscillator / TenTusscher tests respectively
    exclude_osc = ["ESDIRK4O32", "ESDIRK23a"]
    exclude_tent = ["RKF32", "ESDIRK4O32", "ESDIRK23a", "BasicImplicitEuler"]
    dir_path = os.path.dirname(__file__)
    # Vm_reference = np.fromfile(os.path.join(dir_path, "Vm_reference.npy"))
    tentusscher = jit(load_ode("tentusscher_2004_mcell"))

    @parametrize(("solver_str"), goss_solvers)
    def test_convergence_order(self, solver_str):
        """Relative error at t = 10 for u(0) = (1, 0), compared with
        (cos 10, sin 10), converges at the expected order."""
        if solver_str in self.exclude_osc:
            return
        solver = eval(solver_str)(self.oscilator)
        tstop = 10.
        exact = np.array([np.cos(tstop), np.sin(tstop)])
        errors = []
        for dt in [0.05, 0.025, 0.0125, 0.00625]:
            u = np.array([1.0, 0.])
            t = 0.0
            nsteps = int(tstop / dt)
            for step in range(nsteps):
                solver.forward(u, t, dt)
                t += dt
            errors.append(np.sqrt(np.sum(((exact - u) / exact)**2)))

        assert min(convergence_order(errors)) >= self.orders[solver_str] - 0.1

    @parametrize(("solver_str"), goss_solvers)
    def test_long_run(self, solver_str):
        """Stimulated TenTusscher run to t = 10 matches a reference value of
        V (from a run with scipy integrate)."""
        if solver_str in self.exclude_tent:
            return
        dt = 0.0002
        tstop = 10

        ind_V = self.tentusscher.get_state_names().index("V")
        solver = eval(solver_str)(self.tentusscher)

        # The compiled model is shared by the whole class: remember the
        # stimulus parameters and restore them afterwards.
        changed = dict(stim_amplitude=52., stim_start=0.)
        saved = {name: self.tentusscher.get_parameter(name) for name in changed}
        for name, value in changed.items():
            self.tentusscher.set_parameter(name, value)
        try:
            u = self.tentusscher.init_state_values()
            t = 0.0
            nsteps = int(tstop / dt)
            for step in range(nsteps):
                solver.forward(u, t, dt)
                t += dt
        finally:
            for name, value in saved.items():
                self.tentusscher.set_parameter(name, value)

        # Test against run with scipy integrate
        assert abs(u[ind_V] - 12.948) < 1e-3
def _test_codegeneration(
    module,
    body_repr=default_params["code"]["body"]["representation"],
    body_optimize="none",
    param_repr="named",
    state_repr="named",
    use_cse=False,
    use_enum=False,
    float_precision="double",
):
    """Generate rhs/Jacobian code with the given options and check it.

    The reference values come from ``module`` (a previously generated module
    exposing ``init_state_values``, ``init_parameter_values``, ``rhs`` and
    ``compute_jacobian``). The ten Tusscher ODE is reloaded, Python code is
    generated with the requested code-generation options, executed, and its
    output is compared against the reference in the L2 norm.

    Raises:
        AssertionError: if the generated rhs (or Jacobian, when checked)
            deviates from the reference by more than the precision-dependent
            tolerance.
    """
    parameters_name = "parameters"
    states_name = "states"
    body_array_name = "body"
    # Hard-coded: the body array is never passed as an argument, so the
    # ``if body_in_arg`` branch in _call_args is currently inactive.
    body_in_arg = False

    # Reference values from the pre-generated module
    states_values = module.init_state_values()
    parameter_values = module.init_parameter_values()
    rhs_ref_values = module.rhs(states_values, 0.0, parameter_values)
    jac_ref_values = module.compute_jacobian(states_values, 0.0, parameter_values)

    gen_params = parameters["generation"].copy()

    # Update code_params
    code_params = gen_params["code"]
    code_params["body"]["optimize_exprs"] = body_optimize
    code_params["body"]["representation"] = body_repr
    code_params["body"]["in_signature"] = body_in_arg
    code_params["body"]["array_name"] = body_array_name
    code_params["parameters"]["representation"] = param_repr
    code_params["parameters"]["array_name"] = parameters_name
    code_params["states"]["representation"] = state_repr
    code_params["states"]["array_name"] = states_name
    code_params["body"]["use_cse"] = use_cse
    code_params["body"]["use_enum"] = use_enum
    code_params["float_precision"] = float_precision
    code_params["default_arguments"] = "stp"

    def _call_args(comp):
        """Positional arguments for a generated function built from ``comp``."""
        args = [states_values, 0.0]
        # With "numerals" the parameter values are presumably inlined in the
        # generated code, so no parameter array is passed.
        if param_repr != "numerals":
            args.append(parameter_values)
        if body_in_arg:
            args.append(np.zeros(comp.shapes[body_array_name]))
        return args

    # Reload ODE for each test
    ode = gotran.load_ode(_here.joinpath("tentusscher_2004_mcell_updated.ode"))
    codegen = gotran.PythonCodeGenerator(gen_params)
    rhs_comp = gotran.rhs_expressions(ode, params=code_params)
    rhs_code = codegen.function_code(rhs_comp)

    rhs_namespace = {}
    exec(rhs_code, rhs_namespace)

    # Call the generated rhs function
    rhs_values = rhs_namespace["rhs"](*_call_args(rhs_comp))

    # L2 norm of the error: square each component before summing, so errors
    # of opposite sign cannot cancel.
    rhs_norm = np.sqrt(np.sum((rhs_ref_values - rhs_values) ** 2))
    eps = 1e-8 if float_precision == "double" else 1e-6
    assert rhs_norm < eps

    # Skip the Jacobian check when the body is neither "numerals_symbols"-
    # optimized nor a "reused_array", and parameters are "named".
    # NOTE(review): an earlier comment said the Jacobian should only be
    # evaluated with full body optimization AND a reused_array body, which
    # would need different logic; confirm the intended condition.
    if (
        body_optimize != "numerals_symbols"
        and body_repr != "reused_array"
        and param_repr == "named"
    ):
        return

    jac_comp = gotran.jacobian_expressions(ode, params=code_params)
    jac_code = codegen.function_code(jac_comp)
    jac_namespace = {}
    exec(jac_code, jac_namespace)

    jac_values = jac_namespace["compute_jacobian"](*_call_args(jac_comp))
    # Frobenius (elementwise L2) norm of the Jacobian error
    jac_norm = np.sqrt(np.sum((jac_ref_values - jac_values) ** 2))
    eps = 1e-8 if float_precision == "double" else 1e-3
    assert jac_norm < eps
# Mesh and time-stepping setup. x_nodes / y_nodes are defined earlier in the
# script; N equals the vertex count of the UnitSquareMesh created below.
N = (x_nodes + 1) * (y_nodes + 1)
T = 100
dt = 0.5
t = 0
time_steps = int((T - t) / dt)
time_solution_method = 'BE'  ### crank nico
save = False  # save solutions as binary
savemovie = False  # create movie from results. Takes time!
plot_realtime = True  # small hack

mesh = UnitSquareMesh(x_nodes, y_nodes)
space = FunctionSpace(mesh, 'Lagrange', 1)

### Setting up Goss/Gotran part
ode = jit(load_ode("myocyte.ode"))
vertex_to_dof_map = space.dofmap().vertex_to_dof_map(mesh)
N_thread = mesh.coordinates().shape[0]
# Formatted explicitly: print(a, b) would print a tuple on Python 2.
print("%s %s" % (vertex_to_dof_map.shape, mesh.coordinates().shape))

# Stimulus values per mesh vertex (used to build the stimulated field P1)
ist = stimulation_domain(mesh.coordinates(), amp=-10)
P0 = make_parameter_field(mesh.coordinates(), ode)
P1 = make_parameter_field(mesh.coordinates(), ode, ist=ist)
ind_stim = P1[:, 1] != 0
print("P0 %s" % (P0[P0[:, 1] != 0., 1],))
print("P1 %s" % (P1[ind_stim, 1],))

solver = ThetaSolver()
def test_creation():
    """Build an ODE programmatically and exercise the gotran model API.

    Covers state/parameter registration, name-clash errors, intermediates,
    components, state-solve/derivative/algebraic expressions, Markov-model
    rate components, finalization/mass matrix, and a save/load round trip.
    """
    # Adding a phoney ODE
    ode = gotran.ODE("test")

    # Add states and parameters
    j = ode.add_state("j", 1.0)
    i = ode.add_state("i", 2.0)
    k = ode.add_state("k", 3.0)

    ii = ode.add_parameter("ii", 0.0)
    jj = ode.add_parameter("jj", 0.0)
    kk = ode.add_parameter("kk", 0.0)

    # Try overwriting state
    with pytest.raises(gotran.GotranException):
        ode.add_parameter("j", 1.0)

    # Try overwriting parameter
    with pytest.raises(gotran.GotranException):
        ode.add_state("ii", 1.0)

    assert ode.num_states == 3
    assert ode.num_parameters == 3
    assert ode.present_component == ode

    # Add an Expression
    ode.alpha = i * j

    # Add derivatives for all states in the main component
    ode.add_comment("Some nice derivatives and an algebraic expression")
    ode.di_dt = ode.alpha + ii
    ode.dj_dt = -ode.alpha - jj
    ode.alg_k_0 = kk * k * ode.alpha

    assert ode.num_intermediates == 1

    # Add a component with 4 states and 2 parameters
    ode("jada").add_states(m=2.0, n=3.0, l=1.0, o=4.0)
    ode("jada").add_parameters(ll=1.0, mm=2.0)

    # Define a state derivative
    ode("jada").dm_dt = ode("jada").ll - (ode("jada").m - ode.i)

    jada = ode("jada")
    assert ode.present_component == jada

    # Test num_foo
    assert jada.num_states == 4
    assert jada.num_parameters == 2
    assert ode.num_states == 7
    assert ode.num_parameters == 5
    assert ode.num_components == 2
    assert jada.num_components == 1

    # Add expressions to the component
    jada.tmp = jada.ll * jada.m**2 + 3 / i - ii * jj
    jada.tmp2 = ode.j * exp(jada.tmp)

    # Reduce state n
    jada.add_solve_state(jada.n, 1 - jada.l - jada.m - jada.n)

    assert ode.num_intermediates == 4

    # Try overwriting parameter with expression
    with pytest.raises(gotran.GotranException):
        jada.ll = jada.tmp * jada.tmp2

    # Create a derivative expression
    ode.add_comment("More funky objects")
    jada.tmp3 = jada.tmp2.diff(ode.t) + jada.n + jada.o
    jada.add_derivative(jada.l, ode.t, jada.tmp3)
    jada.add_algebraic(jada.o, jada.o**2 - exp(jada.o) + 2 / jada.o)

    assert ode.num_intermediates == 9
    assert ode.num_state_expressions == 6
    assert ode.is_complete
    assert ode.num_full_states == 6

    # Try adding expressions to ode component
    with pytest.raises(gotran.GotranException):
        ode.p = 1.0

    # Check used in and dependencies for one intermediate
    tmp3 = ode.present_ode_objects["tmp3"]
    assert ode.object_used_in[tmp3] == {ode.present_ode_objects["dl_dt"]}
    for sym in symbols_from_expr(tmp3.expr, include_derivatives=True):
        assert (
            ode.present_ode_objects[sympycode(sym)]
            in ode.expression_dependencies[tmp3]
        )

    # Add another component to test rates
    bada = ode("bada")
    bada.add_parameters(nn=5.0, oo=3.0, qq=1.0, pp=2.0)

    nada = bada.add_component("nada")
    nada.add_states(("r", 3.0), ("s", 4.0), ("q", 1.0), ("p", 2.0))
    assert bada.num_parameters == 4
    assert bada.num_states == 4
    nada.p = 1 - nada.r - nada.s - nada.q

    assert "".join(p.name for p in ode.parameters) == "iijjkkllmmnnooppqq"
    assert "".join(s.name for s in ode.states) == "jiklmnorsqp"
    assert not ode.is_complete

    # Add rates to component making it a Markov model component
    nada.rates[nada.r, nada.s] = 3 * exp(-i)

    # Try add a state derivative to Markov model
    with pytest.raises(gotran.GotranException):
        nada.ds_dt = 3.0

    nada.rates[nada.s, nada.r] = 2.0
    nada.rates[nada.s, nada.q] = 2.0
    nada.rates[nada.q, nada.s] = 2 * exp(-i)
    nada.rates[nada.q, nada.p] = 3.0
    nada.rates[nada.p, nada.q] = 4.0

    assert ode.present_component == nada

    markov = bada.add_component("markov_2")
    markov.add_states(("tt", 3.0), ("u", 4.0), ("v", 1.0))

    with pytest.raises(gotran.GotranException):
        markov.rates[nada.s, nada.r] = 2.0

    with pytest.raises(gotran.GotranException):
        markov.rates[markov.tt, markov.u] = 2 * exp(markov.u)

    with pytest.raises(gotran.GotranException):
        markov.rates[[markov.tt, markov.u, markov.v]] = 5.0

    with pytest.raises(gotran.GotranException):
        markov.rates[[markov.tt, markov.u, markov.v]] = Matrix(
            [[1, 2 * i, 0.0], [0.0, 2.0, 4.0]],
        )

    with pytest.raises(gotran.GotranException):
        markov.rates[markov.tt, markov.tt] = 5.0

    markov.rates[[markov.tt, markov.u, markov.v]] = Matrix(
        [[0.0, 2 * i, 2.0], [4.0, 0.0, 2.0], [5.0, 2.0, 0.0]],
    )

    ode.finalize()
    assert ode.is_complete
    assert ode.is_dae

    # Test Mass matrix
    vector = ode.mass_matrix * Matrix([1] * ode.num_full_states)
    assert (0, 0) == (vector[2], vector[5])
    assert sum(ode.mass_matrix) == ode.num_full_states - 2
    assert sum(vector) == ode.num_full_states - 2

    assert "".join(s.name for s in ode.full_states) == "ijkmlorsqttuv"
    assert ode.present_component == ode

    # Test saving, then loading; always remove the written file, even if
    # loading fails, so no stray test_ode.ode is left in the working dir.
    ode.save("test_ode")
    try:
        ode_loaded = gotran.load_ode("test_ode")
    finally:
        os.unlink("test_ode.ode")

    # Check that all objects are the same and evaluate to the same value
    for name, obj in list(ode.present_ode_objects.items()):
        loaded_obj = ode_loaded.present_ode_objects[name]
        assert type(obj) is type(loaded_obj)
        assert loaded_obj.param.value == pytest.approx(obj.param.value)
# Mesh and time-stepping setup. x_nodes / y_nodes are defined earlier in the
# script; N equals the vertex count of the UnitSquareMesh created below.
N = (x_nodes + 1) * (y_nodes + 1)
T = 100
dt = 0.5
t = 0
time_steps = int((T - t) / dt)
time_solution_method = 'BE'  ### crank nico
save = False  # save solutions as binary
savemovie = False  # create movie from results. Takes time!
plot_realtime = True  # small hack

mesh = UnitSquareMesh(x_nodes, y_nodes)
space = FunctionSpace(mesh, 'Lagrange', 1)

### Setting up Goss/Gotran part
ode = jit(load_ode("myocyte.ode"))
vertex_to_dof_map = space.dofmap().vertex_to_dof_map(mesh)
N_thread = mesh.coordinates().shape[0]
# Formatted explicitly: print(a, b) would print a tuple on Python 2.
print("%s %s" % (vertex_to_dof_map.shape, mesh.coordinates().shape))

# Stimulus values per mesh vertex (used to build the stimulated field P1)
ist = stimulation_domain(mesh.coordinates(), amp=-10)
P0 = make_parameter_field(mesh.coordinates(), ode)
P1 = make_parameter_field(mesh.coordinates(), ode, ist=ist)
ind_stim = P1[:, 1] != 0
print("P0 %s" % (P0[P0[:, 1] != 0., 1],))
print("P1 %s" % (P1[ind_stim, 1],))

solver = GRL2()
def pacing_cv(ode, BCL_range, D, L, h, dt):
    """
    Simulate a wave in a 1D strand started from a steady-state cell model.

    Intended purpose: pace a 0D cell model for 50 cycles at a given BCL,
    then initiate 5 beats at the left end of the strand and measure the
    conduction velocity after the fifth beat.

    NOTE(review): the code below does not measure the conduction velocity.
    It loads (or computes) a steady state at a fixed BCL of 1000, then
    alternates ODE and diffusion steps, plotting the strand until t reaches
    1e6. ``BCL_range`` is currently unused.

    Parameters
    ----------
    ode : str
        Cell model passed to ``load_ode``.
    BCL_range :
        Unused.
    D :
        Diffusion coefficient/tensor used in the stiffness form.
    L : float
        Strand length; the mesh spans [0, L].
    h : float
        Element size; ``int(L / h)`` cells are created.
    dt : float
        Time step for both the ODE and the diffusion solves.
    """
    # Create domain
    N = int(L / h)  # number of cells (= number of nodes - 1)
    domain = IntervalMesh(N, 0, L)

    # Define functions
    V = FunctionSpace(domain, 'Lagrange', 1)
    u = TrialFunction(V)
    up = Function(V)
    v = TestFunction(V)

    # Assemble the mass matrix and the (dt-scaled) stiffness matrix. The
    # system matrix A = M + K is factorized once and reused every step.
    M = assemble(u * v * dx)
    K = assemble(Constant(dt) * inner(D * grad(u), grad(v)) * dx)
    A = M + K
    pde_solver = LUSolver(A)
    pde_solver.parameters["reuse_factorization"] = True

    # Set up the ODE solver. ``ode`` is rebound to the loaded model, and the
    # steady-state file name below is formatted from that object.
    ode = load_ode(ode)
    compiled_ode = dolfin_jit(
        ode, field_states=["V"], field_parameters=["stim_period"]
    )
    params = DOLFINODESystemSolver.default_parameters()
    params["scheme"] = "RL1"
    BCL = 1000
    ode_solver = DOLFINODESystemSolver(domain, compiled_ode, params=params)
    solver = ode_solver._ode_system_solvers[0]

    try:
        states = np.load("../data/steadystate_%s_BCL%d.npy" % (ode, BCL))
    except (IOError, OSError, ValueError):
        # No usable cached steady state: compute one by pacing the 0D model
        states = find_steadystate(BCL, 50, dt, ode, plot_results=False)

    # Initialize every node with the steady state and copy state 0
    # (presumably V, the field state) into the PDE solution vector.
    for i in range(N + 1):
        solver.states(i)[:] = states
        up.vector()[i] = solver.states(i)[0]

    # stim_period field parameter: BCL at index 0, zero at the other N entries
    solver.set_field_parameters(np.array([BCL] + [0] * N, dtype=np.float64))

    t = 0.0
    tstop = 1e6
    while t < tstop:
        # Step ODE solver
        ode_solver.step((t, t + dt), up)

        # Assemble RHS and solve pde
        b = M * up.vector()
        pde_solver.solve(up.vector(), b)

        # Plot solution
        plot(up, range_min=-80.0, range_max=40.0)

        t += dt
def spiral(ode, domain, D, S1, S2, plot_args=None, save_fig=False,
           threshold=0, BCL=500):
    """
    Initiate and follow a spiral wave using an S1-S2 stimulation protocol.

    The cell model is initialized at every mesh vertex with a steady state,
    read from ``../data/steadystates/`` or otherwise computed by pacing the
    0D model. The protocol then runs in three phases:

    1. The S1 field is applied on [0, S2_time].
    2. The S2 field is applied on [S2_time, S2_time + 100].
    3. All stimuli are switched off until ``tstop``.

    Every frame is plotted. If ``save_fig`` is true, every fifth frame is
    written to ``'../tmp/spiral_<ode>' + <8-digit index>``.

    Parameters
    ----------
    ode : str
        Cell model name passed to ``load_ode``; also used in file names.
    domain :
        Mesh for the CardiacModel.
    D :
        Conductivity passed to ``CardiacModel``.
    S1, S2 :
        Stimulus expressions, interpolated onto a vector CG1 space and used
        as ODE field parameters.
    plot_args : dict, optional
        Extra keyword arguments forwarded to ``plot``.
    save_fig : bool
        Whether to write every fifth frame to disk.
    threshold : float
        Level whose upward crossing at the last solution dof is recorded.
    BCL : int
        Basic cycle length used to look up or compute the steady state.

    Returns
    -------
    list
        Times (``timestep[0]``) of upward threshold crossings at the last
        dof during the stimulus-free phase.

    Notes
    -----
    ``S2_time``, ``tstop``, ``dt``, ``field_states`` and ``field_parameters``
    are read from module scope.
    """
    if plot_args is None:
        plot_args = {}

    def _render(u, frame):
        """Plot ``u``; save every 5th frame if requested. Return next index."""
        fig = plot(u, interactive=False, **plot_args)
        if save_fig and frame % 5 == 0:
            padded_index = '%08d' % frame
            fig.write_png('../tmp/spiral_%s' % ode + padded_index)
        return frame + 1

    # Load the cell model from file
    cellmodel = load_ode(ode)
    cellmodel = dolfin_jit(cellmodel, field_states, field_parameters)

    # Create the CardiacModel for the given domain and cell model
    heart = CardiacModel(domain, Constant(0.0), D, None, cellmodel)

    # Create the solver
    solver = GOSSplittingSolver(heart, GOSSparams())
    dolfin_solver = solver.ode_solver
    ode_solver = dolfin_solver._ode_system_solvers[0]

    # Calculate parameter fields
    V = VectorFunctionSpace(domain, 'CG', 1)
    S1 = interpolate(S1, V).vector().array()
    S2 = interpolate(S2, V).vector().array()
    nostim = np.zeros(S1.shape, dtype=np.float64)

    # Read in steady state from file, or pace the 0D model if unavailable
    sspath = "../data/steadystates/"
    try:
        steadystate = np.load(sspath + "steadystate_%s_BCL%d.npy" % (ode, BCL))
    except (IOError, OSError, ValueError):
        print("Did not find steadystate for %s at BCL: %g, pacing 0D model" % (ode, BCL))
        # NOTE(review): this previously passed the undefined name ``ODE``;
        # ``ode`` is the model named in the message and the file name.
        # Confirm find_steadystate's expected first argument.
        steadystate = find_steadystate(ode, BCL, 0.01)

    # Load steadystate into 2D model
    u, um = solver.solution_fields()
    for node in range(domain.coordinates().shape[0]):
        ode_solver.states(node)[:] = steadystate
        u.vector()[node] = steadystate[0]

    frame = 0

    # Stimulate S1 pulse (interactive plot of the initial state first)
    plot(u, interactive=True, **plot_args)
    ode_solver.set_field_parameters(S1)
    for timestep, (u, vm) in solver.solve((0, S2_time), dt):
        frame = _render(u, frame)

    # Stimulate S2 pulse
    ode_solver.set_field_parameters(S2)
    for timestep, (u, vm) in solver.solve((S2_time, S2_time + 100), dt):
        frame = _render(u, frame)

    # Turn off all stimuli and iterate until tstop. Upward threshold
    # crossings of the last dof are recorded to estimate the spiral period.
    ode_solver.set_field_parameters(nostim)
    trig_time = []
    xp = 0
    for timestep, (u, vm) in solver.solve((S2_time + 100, tstop), dt):
        frame = _render(u, frame)
        x = u.vector()[-1]
        if x > threshold and (x - threshold) * (xp - threshold) < 0:
            trig_time.append(timestep[0])
            print("Trigger: %g" % timestep[0])
        xp = x

    return trig_time