예제 #1
0
def planar_wave(ode, domain, D, S1, threshold, d=5):
    '''
    Measure and print the conduction velocity of a planar wave.

    A steady state for BCL 1000 is read from
    ``../data/steadystates/steadystate_<ode>_BCL1000.npy`` (or computed with
    ``find_steadystate`` if that file cannot be loaded) and copied into every
    node of ``domain``. ``S1`` is interpolated onto a vector CG1 space and set
    as the ODE field parameters (the stimulus). The problem is then solved
    until node ``N//2-d`` and, after it, node ``N//2+d`` cross ``threshold``
    upwards; the velocity ``2*d*h/(t1-t0)*1e3`` is printed.

    NOTE(review): relies on module-level ``field_states``,
    ``field_parameters``, ``GOSSparams``, ``tstop``, ``dt``, ``N`` and ``h``
    -- confirm they are defined where this function is used.
    '''
    # Load the cell model from file
    cellmodel = load_ode(ode)
    cellmodel = dolfin_jit(cellmodel, field_states, field_parameters)

    # Create the CardiacModel for the given domain and cell model
    heart = CardiacModel(domain, Constant(0.0), D, None, cellmodel)

    # Create the solver
    solver = GOSSplittingSolver(heart, GOSSparams())
    dolfin_solver = solver.ode_solver
    ode_solver = dolfin_solver._ode_system_solvers[0]

    BCL = 1000
    sspath = "../data/steadystates/"
    try:
        steadystate = np.load(sspath+"steadystate_%s_BCL%d.npy" % (ode, BCL))
    except (IOError, OSError, ValueError):
        # Missing or unreadable file: fall back to pacing the 0D model.
        # (A bare except here also swallowed KeyboardInterrupt/NameError.)
        print("Did not find steadystate for %s at BCL: %g, pacing 0D model" % (ode, BCL))
        # NOTE(review): ``ODE`` is not defined in this function (the
        # parameter is ``ode``) -- confirm it is a module-level name.
        steadystate = find_steadystate(ODE, BCL, 0.01)

    # Load steadystate into 2D model.
    # NOTE(review): u.vector() is indexed by vertex number here; this is only
    # right if vertex and dof numbering coincide -- confirm.
    (u, um) = solver.solution_fields()
    for node in range(domain.coordinates().shape[0]):
        ode_solver.states(node)[:] = steadystate
        u.vector()[node] = steadystate[0]

    # Define the planar wave stimulus and apply it as field parameters
    V = VectorFunctionSpace(domain, 'CG', 1)
    S1 = interpolate(S1, V).vector().array()
    ode_solver.set_field_parameters(S1)

    # Previous samples of the two probe nodes, used to detect crossings
    xp = 0
    yp = 0
    # Activation time of the X node; None until it has fired (a time of 0
    # is a valid activation time, so 0 cannot be used as the sentinel)
    t0 = None
    for timestep, (u, vm) in solver.solve((0, tstop), dt):
        x = u.vector()[N//2-d]
        y = u.vector()[N//2+d]
        print("%s %s %s" % (u.vector().max(), x, y))
        if x > threshold and (x-threshold)*(xp-threshold) < 0:
            # X cell has activated (upward threshold crossing)
            t0 = timestep[0]
        if t0 is not None and y > threshold and (y-threshold)*(yp-threshold) < 0:
            # Y cell has activated after X: compute conduction velocity
            t1 = timestep[0]
            print("CV: %s" % (2*d*h/(t1-t0)*1e3))
            break
        xp = x
        yp = y
예제 #2
0
def module():
    """Generate, execute and return a Python module for the reference ODE.

    The ``tentusscher_2004_mcell_updated.ode`` file located via ``_here`` is
    loaded with gotran, rendered to Python source by a
    ``PythonCodeGenerator`` configured with ``default_params``, and that
    source is executed into a fresh module object named ``"ode"``.
    """
    reference_ode = gotran.load_ode(
        _here.joinpath("tentusscher_2004_mcell_updated.ode"))

    generator = gotran.PythonCodeGenerator(default_params)
    source = generator.module_code(reference_ode)

    # Build an empty module object (no loader) and populate it by executing
    # the generated source in its namespace.
    module_spec = importlib.util.spec_from_loader("ode", loader=None)
    generated = importlib.util.module_from_spec(module_spec)
    exec(source, generated.__dict__)

    return generated
예제 #3
0
    def __init__(self, **kwargs):
        """Configure the test case from defaults overridden by ``kwargs``.

        ``kwargs`` are validated with ``COSSTestCase.check_kwargs`` and merged
        over ``COSSTestCase.default_parameters()``; every resulting entry
        becomes an instance attribute. ``field_states`` / ``field_parameters``
        may be given as a dict keyed by model name, in which case the entry for
        ``ode_model`` is selected. The ODE is loaded, and the optional getter
        callbacks are used to build ``field_states_fn`` and
        ``field_parameter_values``.
        """
        COSSTestCase.check_kwargs(**kwargs)
        params = COSSTestCase.default_parameters()
        params.update(kwargs)

        # Per-model dicts are resolved to the entry of the selected model
        for key in ("field_states", "field_parameters"):
            if key in params and isinstance(params[key], dict):
                params[key] = params[key][params["ode_model"]]

        self.__dict__.update(params)

        print(self.ode_model)
        self.ode = load_ode(self.ode_model)

        if self.field_states_getter_fn is None:
            self.stored_field_states = None
            self.field_states_fn = None
        else:
            # The getter receives this list so that it can record the field
            # states it produces.
            self.stored_field_states = list()
            self.field_states_fn = self.field_states_getter_fn(
                self.num_nodes,
                self.stored_field_states,
            )

        # Initial values of the parameters that are treated as field parameters
        initial_field_params = [
            param.init for param in self.ode.parameters
            if param.name in self.field_parameters
        ]
        if self.field_parameter_values_getter_fn is None:
            self.field_parameter_values = None
        else:
            self.field_parameter_values = self.field_parameter_values_getter_fn(
                initial_field_params,
                self.num_nodes,
                self.double,
            )
예제 #4
0
def test_extract_components():
    """Extracted component sub-ODEs keep the parameter values of the full ODE."""
    full_ode = gotran.load_ode(_here.joinpath("tentusscher_2004_mcell_updated.ode"))

    # (name of the extracted ODE, components to pull out of the full model)
    extractions = (
        ("Potassium", (
            "Rapid time dependent potassium current",
            "Inward rectifier potassium current",
            "Slow time dependent potassium current",
            "Potassium pump current",
            "Potassium dynamics",
            "Transient outward current",
        )),
        ("Sodium", (
            "Fast sodium current",
            "Sodium background current",
            "Sodium potassium pump current",
            "Sodium calcium exchanger current",
            "Sodium dynamics",
        )),
        ("Calcium", (
            "Calcium dynamics",
            "Calcium background current",
            "Calcium pump current",
            "L type ca current",
        )),
    )

    for sub_name, component_names in extractions:
        sub_ode = full_ode.extract_components(sub_name, *component_names)
        for obj_name, sub_obj in list(sub_ode.present_ode_objects.items()):
            reference = full_ode.present_ode_objects[obj_name]
            assert reference.param.value == pytest.approx(sub_obj.param.value)
 def __init__(
     self,
     ode,
     num_nodes,
     dt,
     tstop,
     t0=0.0,
     solver="rush_larsen",
     field_states=None,
     field_states_getter_fn=None,
     field_parameters=None,
     field_parameter_values_getter_fn=None,
     block_size=1024,
     double=True,
     statesrepr="named",
     paramrepr="named",
     bodyrepr="named",
     use_cse=False,
     update_host_states=False,
     update_field_states=True,
 ):
     """Load ``ode`` and store the solver run configuration as attributes.

     ``field_states`` and ``field_parameters`` default to ``[""]``. If given,
     ``field_states_getter_fn(num_nodes)`` provides ``field_states_fn`` and
     ``field_parameter_values_getter_fn(num_nodes, double)`` provides
     ``field_parameter_values``; otherwise both are ``None``.
     """
     # Fresh default lists per call: a list literal as the default value is
     # shared between calls, and since it is stored on the instance, editing
     # it through one instance would change the default for all later ones.
     if field_states is None:
         field_states = [""]
     if field_parameters is None:
         field_parameters = [""]

     self.ode = load_ode(ode)
     self.num_nodes = num_nodes
     self.dt = dt
     self.tstop = tstop
     self.t0 = t0
     self.solver = solver
     self.field_states = field_states
     self.field_states_fn = (field_states_getter_fn(num_nodes)
                             if field_states_getter_fn is not None else None)
     self.field_parameters = field_parameters
     self.field_parameter_values = (
         field_parameter_values_getter_fn(num_nodes, double)
         if field_parameter_values_getter_fn is not None else None)
     self.block_size = block_size
     self.double = double
     self.statesrepr = statesrepr
     self.paramrepr = paramrepr
     self.bodyrepr = bodyrepr
     self.use_cse = use_cse
     self.update_host_states = update_host_states
     self.update_field_states = update_field_states
예제 #6
0
    def generate_code(gen_params=None):
        """Return generated Python function code for several ODE components.

        The reference tentusscher_2004_mcell_updated ODE is loaded and the
        rhs, selected monitored intermediates, the derivative of state 15,
        the linearized derivatives, the Jacobian, its diagonal and its action
        are each rendered with ``PythonCodeGenerator(gen_params)``.
        ``gen_params.code`` (when ``gen_params`` is given) is forwarded to the
        expression builders.
        """
        reference_ode = gotran.load_ode(_here.joinpath("tentusscher_2004_mcell_updated.ode"))
        if gen_params is None:
            expr_params = None
        else:
            expr_params = gen_params.code
        generator = gotran.PythonCodeGenerator(gen_params)
        jacobian = gotran.jacobian_expressions(reference_ode, params=expr_params)

        monitored = ["i_NaK", "i_NaCa", "i_CaL", "d_fCa"]
        components = (
            gotran.rhs_expressions(reference_ode, params=expr_params),
            gotran.monitored_expressions(reference_ode, monitored, params=expr_params),
            gotran.componentwise_derivative(reference_ode, 15, params=expr_params),
            gotran.linearized_derivatives(reference_ode, params=expr_params),
            jacobian,
            gotran.diagonal_jacobian_expressions(jacobian, params=expr_params),
            gotran.jacobian_action_expressions(jacobian, params=expr_params),
        )

        return list(map(generator.function_code, components))
예제 #7
0
    def __init__(self, ode):
        """Set up a GOSS splitting solver for ``ode`` on a square mesh.

        Uses module-level ``L``, ``N``, ``D``, ``field_states`` and
        ``field_parameters``. Stores the splitting solver, the ODE name, the
        DOLFIN ODE solver, its first ODE system solver and the mesh.
        """
        # L x L square split into N x N cells
        mesh = RectangleMesh(0, 0, L, L, N, N)

        # Cell model: load from file, then JIT compile with the field selection
        compiled_model = dolfin_jit(load_ode(ode), field_states, field_parameters)

        # Cardiac model over the mesh, followed by its splitting solver
        cardiac = CardiacModel(mesh, Constant(0.0), D, None, compiled_model)
        splitting = GOSSplittingSolver(cardiac, self.GOSSparams())

        # Sub-solvers are resolved before any attribute is set
        dolfin_ode = splitting.ode_solver
        system_solver = dolfin_ode._ode_system_solvers[0]

        self.solver = splitting
        self.ode = ode
        self.dolfin_solver = dolfin_ode
        self.ode_solver = system_solver
        self.domain = mesh
예제 #8
0
    def __init__(self, ode):
        """Set up a GOSS splitting solver for ``ode`` on a square mesh.

        Uses module-level ``L``, ``N``, ``D``, ``field_states`` and
        ``field_parameters``. Stores the splitting solver, the ODE name, the
        DOLFIN ODE solver, its first ODE system solver and the mesh.
        """
        # L x L square split into N x N cells
        mesh = RectangleMesh(0, 0, L, L, N, N)

        # Cell model: load from file, then JIT compile with the field selection
        compiled_model = dolfin_jit(load_ode(ode), field_states, field_parameters)

        # Cardiac model over the mesh, followed by its splitting solver
        cardiac = CardiacModel(mesh, Constant(0.0), D, None, compiled_model)
        splitting = GOSSplittingSolver(cardiac, self.GOSSparams())

        # Sub-solvers are resolved before any attribute is set
        dolfin_ode = splitting.ode_solver
        system_solver = dolfin_ode._ode_system_solvers[0]

        self.solver = splitting
        self.ode = ode
        self.dolfin_solver = dolfin_ode
        self.ode_solver = system_solver
        self.domain = mesh
예제 #9
0
# Mark sinoatrial-node vertices: BZ is 1 on every vertex of an element with
# M > 10 and -1 everywhere else.
BZ = -numpy.ones(mesh.num_vertices())
edx = numpy.nonzero(M > 10)[0]  # The elements in SAN
idx = numpy.unique(E[edx, 0:3])  # The nodes in these elements
idx = idx.astype(int)
BZ[idx] = 1

# Show the marking in a Viper window
pv = Viper(meshlist, BZ)
pv.interactive()

idx0 = np.nonzero(BZ >= 0)[0]  # SA cells
idx1 = np.nonzero(BZ < 0)[0]  # normal cells

m0 = len(idx0)
m1 = len(idx1)

ode0 = jit(load_ode("difrancesco"))
ode1 = jit(load_ode("myocyte.ode"))

# ImplicitEuler() is an alternative scheme for either population
solver0 = GRL2()
solver1 = GRL2()

system_solver0 = ODESystemSolver(m0, solver0, ode0)
system_solver1 = ODESystemSolver(m1, solver1, ode1)

# BZ is one-dimensional, so it is indexed with idx0 alone; the former
# BZ[idx0, :] raised "IndexError: too many indices".
P = make_parameter_field(m0, ode0, {'distance': BZ[idx0]})
system_solver0.set_field_parameters(P)

# NOTE(review): sized for one field state per SA node -- confirm ode0 has
# exactly one field state.
V0 = np.zeros(m0)
system_solver0.get_field_states(V0)
예제 #10
0
axis = [-80.0, 40.0]
save_fig = True

# Keyword arguments for plot(); 'window_width'/'window_height' (for example
# 1366 x 744) can be added here.
plot_args = dict(range_min=0.0, range_max=1.0, mode='color')

ode = 'FK_cAF'
domain = Mesh('atrial_mesh.xml')

# Load the cell model from file and JIT compile it
cellmodel = dolfin_jit(load_ode(ode), field_states, field_parameters)


def GOSSparams():
    """Return GOSSplittingSolver defaults configured for this simulation.

    Monodomain PDE solved iteratively with theta = 1.0, RL1 ODE scheme, and
    the stimulus current not applied to the PDE.
    """
    params = GOSSplittingSolver.default_parameters()
    params["pde_solver"] = "monodomain"

    monodomain = params["MonodomainSolver"]
    monodomain["linear_solver_type"] = "iterative"
    monodomain["theta"] = 1.0

    params["ode_solver"]["scheme"] = "RL1"
    params["apply_stimulus_current_to_pde"] = False
    return params


# Create the CardiacModel for the given domain and cell model
# (time Constant(0.0), conductivity D, no applied stimulus).
# NOTE(review): D, field_states and field_parameters are not defined in the
# visible part of this script -- confirm they are set before this point.
heart = CardiacModel(domain, Constant(0.0), D, None, cellmodel)
예제 #11
0
def pacing_cv(ode, BCL_range, D, dt, threshold=0, stim_amp=0, plot_args=None):
    """Measure conduction velocity in a 1D strand for several pacing rates.

    For each ``BCL`` in ``BCL_range`` the stimulus field parameters are set,
    a quasi steady state is loaded into every node (from
    ``../data/steadystates/`` or computed with ``find_steadystate``), and
    three pulses are simulated. In each pulse the upward ``threshold``
    crossings of nodes 20 and 35 are located by linear interpolation, giving
    ``Cv = 15*h/(t1-t0)*1e3``.

    The results are saved with ``np.save`` to ``../data/results/cv_<ode>`` as
    a ``(4, len(BCL_range))`` array: row 0 holds the BCLs, rows 1-3 the CV of
    pulses 1-3 (0 where none was measured).
    """
    D = Constant(D)  # Must be adjusted with temporal and spatial resolution
    L = 20
    h = 0.5
    dt_low = 0.1

    # Create dolfin mesh
    N = int(L/h)
    domain = IntervalMesh(N, 0, L)

    # Load the cell model from file
    ode = load_ode(ode)
    cellmodel = dolfin_jit(ode, field_states=["V"],
                           field_parameters=["stim_period", "stim_amplitude"])

    # Create the CardiacModel for the given mesh and cell model
    heart = CardiacModel(domain, Constant(0.0), D, None, cellmodel)

    # Create the solver
    solver = GOSSplittingSolver(heart, GOSSparams())

    # Get the solution fields and subsolvers
    dolfin_solver = solver.ode_solver
    ode_solver = dolfin_solver._ode_system_solvers[0]

    results = np.zeros((4, len(BCL_range)))

    for i, BCL in enumerate(BCL_range):
        # Set the stimulus parameter field
        stim_field = np.zeros(2*(N+1), dtype=np.float64)
        stim_field[0] = BCL
        stim_field[1] = stim_amp
        ode_solver.set_field_parameters(stim_field)

        # Pace 0D cell model and find quasi steady state
        sspath = "../data/steadystates/"
        try:
            states = np.load(sspath+"steadystate_%s_BCL%d.npy" % (ode, BCL))
        except (IOError, OSError, ValueError):
            print("Did not find steadystate for %s at BCL: %g, pacing 0D model" % (ode, BCL))
            # The result used to be discarded, leaving ``states`` undefined.
            # NOTE(review): ``ODE`` is not defined in this function -- confirm
            # it is a module-level name.
            states = find_steadystate(ODE, BCL, 0.01)

        # Load quasi steady state into 1D model
        (u, um) = solver.solution_fields()
        for node in range(N+1):
            ode_solver.states(node)[:] = states
            u.vector()[node] = states[0]

        # Used for measuring cell activation
        xp = 0
        yp = 0
        results[0][i] = BCL
        print("BCL: %g" % BCL)
        # Do 3 pulses
        for pulsenr in range(1, 4):
            # Activation time of node 20 in this pulse; None until it fires,
            # so a node-35 crossing without a prior node-20 one is ignored
            # instead of using an undefined or stale t0.
            t0 = None
            # tp lags one step behind timestep[0]; seeded so a crossing on the
            # first step has a defined time. NOTE(review): assumes the first
            # yielded timestep[0] equals i*BCL.
            tp = i*BCL - dt
            # Solve the pulse in higher temporal resolution
            for timestep, (u, vm) in solver.solve((i*BCL, i*BCL+50), dt):
                if plot_args:
                    plot(u, **plot_args)

                x = u.vector()[20]
                y = u.vector()[35]

                if x > threshold and (x-threshold)*(xp-threshold) < 0:
                    t0 = (threshold - xp)/(x - xp)*dt + tp
                if (t0 is not None and y > threshold
                        and (y-threshold)*(yp-threshold) < 0):
                    t1 = (threshold - yp)/(y - yp)*dt + tp
                    # Calculate conduction velocity over 15 nodes
                    Cv = 15*h/(t1-t0)*1e3
                    print("\tCv: %g" % Cv)
                    results[pulsenr][i] = Cv
                    t0 = None

                xp = x
                yp = y
                tp = timestep[0]

            # Wait for next pulse
            for timestep, (u, vm) in solver.solve((i*BCL+50, (i+1)*BCL), dt_low):
                if plot_args:
                    plot(u, **plot_args)

    np.save("../data/results/cv_%s" % ode, results)
	
	# Split the mesh vertices into SA-node cells and normal cells and create
	# one ODE system solver per population.
	# NOTE(review): coordinates()[vertex_to_dof_map] reorders vertex
	# coordinates by the vertex->dof map; dof-ordered coordinates would
	# normally use dof_to_vertex_map -- confirm the intended ordering.
	vertex_to_dof_map =  V.dofmap().vertex_to_dof_map(mesh)
	fenics_ordered_coordinates = mesh.coordinates()[vertex_to_dof_map]
	N_thread = fenics_ordered_coordinates.shape[0]
	BZ = find_leaf_nodes() # presumably the distances to the leaf nodes (the original comment named this 'p')
	idx = BZ >= 0;

	# celltype: 0 where BZ >= 0, 1 elsewhere
	celltype = np.ones(N_thread)
	celltype[idx] = 0;


	idx0 = np.nonzero(celltype==0)[0] # SA cells
	idx1 = np.nonzero(celltype==1)[0] # normal cells
	
	ode0 = jit(load_ode("difrancesco.ode"))
	ode1 = jit(load_ode("myocyte.ode"))

	solvermethod0 = ImplicitEuler()
	solvermethod1 = ImplicitEuler()

	m0 = len(idx0)
	m1 = len(idx1)
	# Total node count (not used in the visible code)
	N = m0+m1


	ode_solver0 = ODESystemSolver(m0, solvermethod0, ode0)
	ode_solver1 = ODESystemSolver(m1, solvermethod1, ode1)

	# NOTE(review): advance0/advance1 are not defined in the visible code.
	goss_wrap0 = Goss_wrapper(ode_solver0, advance0, V)
	goss_wrap1 = Goss_wrapper(ode_solver1, advance1, V)
def setup_model(cellmodel_strs, domain, space):
    """Run and plot a DOLFIN ODE-system simulation with spatially varying parameters.

    Each model name in ``cellmodel_strs`` ("rice_model_2008" and/or
    "hybrid") is loaded and JIT compiled with its field states and
    parameters. Selected parameters are replaced by Gaussian-shaped fields on
    ``space``: the default value far from the bump, a target value at its
    centre. With two models the domain is labelled 10 everywhere and 20 where
    ``x[1] <= 0.5``. The system is stepped with RL1 from t=0 to t=300
    (dt=0.1), plotting the first field state roughly every 10 time units.

    Parameters
    ----------
    cellmodel_strs : list of str
        One or two of "rice_model_2008", "hybrid".
    domain : Mesh
    space : str
        Space string parsed by ``family_and_degree_from_str`` (e.g. "P_1").

    Raises
    ------
    ValueError
        If ``cellmodel_strs`` does not name exactly one or two models.
    """
    L = domain.coordinates().max()
    family, degree = family_and_degree_from_str(space)
    V = FunctionSpace(domain, family, degree)

    # Scalar P1 space used for plotting
    Vs = FunctionSpace(domain, "P", 1)

    # The vertex-based space uses the "P_1" spelling (see the VertexFunction
    # choice below). Earlier comparisons against "P1" never matched, so the
    # solution was always projected for plotting.
    is_p1 = space == "P_1"

    field_parameters = dict(rice_model_2008=["hf", "hb"],
                            hybrid=["TMon_coop", "TMon_pow"])
    field_states = dict(rice_model_2008=["active"],
                        hybrid=["active"])  #, "TCa"

    # Gaussian bump used for the spatially varying parameters
    param_scale = Expression(
        "offset+scale*exp(-((x[0]-center_x)*(x[0]-center_x)+"
        "(x[1]-center_y)*(x[1]-center_y))/(sigma*sigma))",
        center_x=3*L/4, center_y=L/4, offset=0.0, sigma=L/2, scale=1.0)

    cellmodels = [dolfin_jit(load_ode(cellmodel),
                             field_states=field_states[cellmodel],
                             field_parameters=field_parameters[cellmodel])
                  for cellmodel in cellmodel_strs]

    # Upper end of the plotted range (same for both supported models)
    max_value = 0.14

    # model -> [(parameter, value at the bump centre, center_y in units of L/4)]
    spatial_params = [
        ("rice_model_2008", [("hf", 0.03, 2.5), ("hb", 0.06, 3)]),
        ("hybrid", [("TMon_coop", 0.5, 1.5), ("TMon_pow", 0.5, 2.0)]),
    ]
    for model_name, specs in spatial_params:
        if model_name not in cellmodel_strs:
            continue
        model = cellmodels[cellmodel_strs.index(model_name)]
        for param, target, y_shift in specs:
            p0 = model.get_parameter(param)
            # offset + scale == target at the centre, p0 far away
            param_scale.offset = p0
            param_scale.scale = -(p0 - target)
            param_scale.center_y = y_shift * L / 4
            p_func = Function(V, name=param)
            p_func.interpolate(param_scale)
            model.set_parameter(param, p_func)

    if len(cellmodels) == 2:
        subdomain = CompiledSubDomain("x[1] <= 0.5")
        if is_p1:
            cellmodel_domains = VertexFunction("size_t", domain, 10)
        else:
            cellmodel_domains = CellFunction("size_t", domain, 10)
        subdomain.mark(cellmodel_domains, 20)
        cellmodels = dict(zip([10, 20], cellmodels))
    elif len(cellmodels) == 1:
        cellmodels = cellmodels[0]
        cellmodel_domains = None
    else:
        # An assert here would vanish under -O and fail later with a NameError
        raise ValueError("Expected one or two cell models, got %d" % len(cellmodels))

    params = DOLFINODESystemSolver.default_parameters()
    params["solver"] = "RL1"
    solver = DOLFINODESystemSolver(domain, cellmodels, domains=cellmodel_domains,
                                   space=space, params=params)
    u = Function(solver.state_space)
    solver.from_field_states(u)

    if not is_p1:
        u_plot = Function(Vs)
        if solver.num_field_states > 1:
            u_plot.assign(project(u[0], Vs))
        else:
            u_plot.assign(project(u, Vs))
    elif solver.num_field_states > 1:
        u_plot = u.split(True)[0]
    else:
        u_plot = u

    t = 0
    dt = .1
    tstop = 300

    while t < tstop:
        solver.step((t, t + dt), u)
        if (t % 10.) < dt:
            print("t: %s" % t)
            if not is_p1:
                if solver.num_field_states > 1:
                    u_plot.assign(project(u[0], Vs))
                else:
                    u_plot.assign(project(u, Vs))
            elif solver.num_field_states > 1:
                assign(u_plot, u.sub(0))

            plot(u_plot, scale=0., range_max=max_value, range_min=0.)

        t += dt

    plot(u_plot, scale=0., range_max=max_value, range_min=0., interactive=True)
예제 #14
0
#import os
from gotran import load_ode
from goss import *
import numpy as np

oscilator = jit(load_ode("oscilator"))

# Alternative scheme: ExplicitEuler(oscilator)
solver1 = RL1(oscilator)

tstop = 10.
n_steps = 1000
time = np.linspace(0, tstop, n_steps + 1)

dt = time[1] - time[0]
u0 = np.array([1.0, 0.0])

# Exact solution of the oscillator for comparison
u_exact = np.array([np.cos(time), np.sin(time)])
u1 = np.zeros_like(u_exact)
u2 = np.zeros_like(u_exact)
u1[:, 0] = u0
u2[:, 0] = u0

print(u1)
# forward() updates u in place; copy so u0 keeps the initial condition
u = u0.copy()
for step in range(1, n_steps + 1):
    # Step from the start of the interval, time[step - 1], so that after
    # forward(u, t, dt) the state corresponds to time[step].
    # NOTE(review): assumes forward(y, t, dt) advances y from t to t + dt.
    t = time[step - 1]
    solver1.forward(u, t, dt)
    u1[:, step] = u
예제 #15
0
class TestODESystemSolver(object):
    """Tests for the JIT compiled ODE interface and ``ODESystemSolver``."""

    # Compiled once, when the class body runs, and shared by every test.
    ode = jit(load_ode("tentusscher_2004_mcell"),
              field_states=["V", "Ca_i"],
              field_parameters=["g_CaL", "K_o"])

    def test_ode_interface(self):
        assert self.ode.num_field_states() == 2
        assert self.ode.num_field_parameters() == 2
        assert self.ode.num_states() == 17
        assert self.ode.num_parameters() == 45

        Cm = self.ode.get_parameter("Cm")
        assert Cm == 0.185

        # self.ode is shared class state: restore Cm even if the check fails
        self.ode.set_parameter("Cm", 0.2)
        try:
            assert self.ode.get_parameter("Cm") == 0.2
        finally:
            self.ode.set_parameter("Cm", 0.185)

        assert isinstance(self.ode.get_state_names(), list)
        assert isinstance(self.ode.get_parameter_names(), list)
        assert len(self.ode.get_state_names()) == 17
        assert len(self.ode.get_parameter_names()) == 45

        assert "stim_start" in self.ode.get_parameter_names()
        state_names = [
            "Xr1", "Xr2", "Xs", "m", "h", "j", "d", "f", "fCa", "s", "r", "g",
            "Ca_i", "Ca_SR", "Na_i", "V", "K_i"
        ]
        state_names.sort()
        ode_state_names = self.ode.get_state_names()
        ode_state_names.sort()
        assert ode_state_names == state_names

        # Names are case sensitive and unknown names raise
        with pytest.raises(RuntimeError):
            self.ode.set_parameter("cm", 0.185)

        with pytest.raises(RuntimeError):
            self.ode.get_parameter("JADA")

    def test_ode_system_interface(self):

        num_nodes = 100
        solver = RL1()
        system = ODESystemSolver(num_nodes, solver, self.ode)
        system.reset_default()
        field_states = np.zeros(self.ode.num_field_states() * num_nodes)
        field_parameters = np.zeros(self.ode.num_field_parameters() *
                                    num_nodes)

        assert isinstance(system.states(), np.ndarray)
        assert len(system.states()) == num_nodes * self.ode.num_states()

        # Check tangled access (V and Ca_i interleaved per node)
        system.get_field_states(field_states, True)
        assert sum(field_states[::2] == -86.2) == num_nodes
        assert sum(field_states[1::2] == 0.0002) == num_nodes

        # Check untangled access (all V first, then all Ca_i)
        system.get_field_states(field_states, False)
        assert sum(field_states[:num_nodes] == -86.2) == num_nodes
        assert sum(field_states[num_nodes:] == 0.0002) == num_nodes

        assert system.num_nodes() == num_nodes

        # Wrong length -> ValueError, wrong dtype -> TypeError
        with pytest.raises(ValueError):
            system.set_field_states(
                np.zeros(self.ode.num_field_states() * num_nodes // 2))

        with pytest.raises(TypeError):
            system.set_field_states(
                np.zeros(self.ode.num_field_states() * num_nodes, dtype=int))

        # Field-parameter arrays are sized by the number of field parameters
        with pytest.raises(ValueError):
            system.set_field_parameters(
                np.zeros(self.ode.num_field_parameters() * num_nodes // 2))

        with pytest.raises(TypeError):
            system.set_field_parameters(
                np.zeros(self.ode.num_field_parameters() * num_nodes,
                         dtype=int))
예제 #16
0
# Parameters passed to dolfin_jit as field parameters
field_parameters = ["stim_amplitude", "stim_offset", "stim_period"]
axis = [-80.0, 40.0]
save_fig = True

# Keyword arguments for plot(); 'window_width'/'window_height' (for example
# 1366 x 744) can be added here.
plot_args = dict(range_min=0.0, range_max=1.0, mode='color')

ode = 'FK_cAF'
domain = Mesh('atrial_mesh.xml')

# Load the cell model from file and JIT compile it
cellmodel = dolfin_jit(load_ode(ode), field_states, field_parameters)

def GOSSparams():
    """Return GOSSplittingSolver defaults configured for this simulation.

    Monodomain PDE solved iteratively with theta = 1.0, RL1 ODE scheme, and
    the stimulus current not applied to the PDE.
    """
    params = GOSSplittingSolver.default_parameters()
    params["pde_solver"] = "monodomain"

    monodomain = params["MonodomainSolver"]
    monodomain["linear_solver_type"] = "iterative"
    monodomain["theta"] = 1.0

    params["ode_solver"]["scheme"] = "RL1"
    params["apply_stimulus_current_to_pde"] = False
    return params

# Create the CardiacModel for the given domain and cell model
# (time Constant(0.0), conductivity D, no applied stimulus).
# NOTE(review): D and field_states are not defined in the visible part of
# this script -- confirm they are set before this point.
heart = CardiacModel(domain, Constant(0.0), D, None, cellmodel)

# Create the solver
# NOTE(review): no solver construction follows in the visible source.
예제 #17
0
def pacing_cv(ode, BCL_range, D, dt, threshold=0, stim_amp=0, plot_args=None):
    """Measure conduction velocity in a 1D strand for several pacing rates.

    For each ``BCL`` in ``BCL_range`` the stimulus field parameters are set,
    a quasi steady state is loaded into every node (from
    ``../data/steadystates/`` or computed with ``find_steadystate``), and
    three pulses are simulated. In each pulse the upward ``threshold``
    crossings of nodes 20 and 35 are located by linear interpolation, giving
    ``Cv = 15 * h / (t1 - t0) * 1e3``.

    The results are saved with ``np.save`` to ``../data/results/cv_<ode>`` as
    a ``(4, len(BCL_range))`` array: row 0 holds the BCLs, rows 1-3 the CV of
    pulses 1-3 (0 where none was measured).
    """
    D = Constant(D)  # Must be adjusted with temporal and spatial resolution
    L = 20
    h = 0.5
    dt_low = 0.1

    # Create dolfin mesh
    N = int(L / h)
    domain = IntervalMesh(N, 0, L)

    # Load the cell model from file
    ode = load_ode(ode)
    cellmodel = dolfin_jit(ode,
                           field_states=["V"],
                           field_parameters=["stim_period", "stim_amplitude"])

    # Create the CardiacModel for the given mesh and cell model
    heart = CardiacModel(domain, Constant(0.0), D, None, cellmodel)

    # Create the solver
    solver = GOSSplittingSolver(heart, GOSSparams())

    # Get the solution fields and subsolvers
    dolfin_solver = solver.ode_solver
    ode_solver = dolfin_solver._ode_system_solvers[0]

    results = np.zeros((4, len(BCL_range)))

    for i, BCL in enumerate(BCL_range):
        # Set the stimulus parameter field
        stim_field = np.zeros(2 * (N + 1), dtype=np.float64)
        stim_field[0] = BCL
        stim_field[1] = stim_amp
        ode_solver.set_field_parameters(stim_field)

        # Pace 0D cell model and find quasi steady state
        sspath = "../data/steadystates/"
        try:
            states = np.load(sspath + "steadystate_%s_BCL%d.npy" % (ode, BCL))
        except (IOError, OSError, ValueError):
            print("Did not find steadystate for %s at BCL: %g, pacing 0D model"
                  % (ode, BCL))
            # The result used to be discarded, leaving ``states`` undefined.
            # NOTE(review): ``ODE`` is not defined in this function -- confirm
            # it is a module-level name.
            states = find_steadystate(ODE, BCL, 0.01)

        # Load quasi steady state into 1D model
        (u, um) = solver.solution_fields()
        for node in range(N + 1):
            ode_solver.states(node)[:] = states
            u.vector()[node] = states[0]

        # Used for measuring cell activation
        xp = 0
        yp = 0
        results[0][i] = BCL
        print("BCL: %g" % BCL)
        # Do 3 pulses
        for pulsenr in range(1, 4):
            # Activation time of node 20 in this pulse; None until it fires,
            # so a node-35 crossing without a prior node-20 one is ignored
            # instead of using an undefined or stale t0.
            t0 = None
            # tp lags one step behind timestep[0]; seeded so a crossing on the
            # first step has a defined time. NOTE(review): assumes the first
            # yielded timestep[0] equals i * BCL.
            tp = i * BCL - dt
            # Solve the pulse in higher temporal resolution
            for timestep, (u, vm) in solver.solve((i * BCL, i * BCL + 50), dt):
                if plot_args:
                    plot(u, **plot_args)

                x = u.vector()[20]
                y = u.vector()[35]

                if x > threshold and (x - threshold) * (xp - threshold) < 0:
                    t0 = (threshold - xp) / (x - xp) * dt + tp
                if (t0 is not None and y > threshold
                        and (y - threshold) * (yp - threshold) < 0):
                    t1 = (threshold - yp) / (y - yp) * dt + tp
                    # Calculate conduction velocity over 15 nodes
                    Cv = 15 * h / (t1 - t0) * 1e3
                    print("\tCv: %g" % Cv)
                    results[pulsenr][i] = Cv
                    t0 = None

                xp = x
                yp = y
                tp = timestep[0]

            # Wait for next pulse
            for timestep, (u, vm) in solver.solve(
                (i * BCL + 50, (i + 1) * BCL), dt_low):
                if plot_args:
                    plot(u, **plot_args)

    np.save("../data/results/cv_%s" % ode, results)
예제 #18
0
from gotran import CCodeGenerator
from gotran import load_ode
from gotran import ODERepresentation

drv.init()
# Handle to the first CUDA device (previously created twice for no reason)
dev = drv.Device(0)
# Architecture string such as "sm_35" built from the compute capability
arch = "sm_%d%d" % dev.compute_capability()

# Keyword options forwarded to ODERepresentation (none set)
optimisations = dict()

filename = "tentusscher_panfilov_2006_M_cell_continuous"

ode = load_ode(filename)

# Get num states and parameters which sets the offset into the state
# and parameter array

num_states = ode.num_states
num_params = ode.num_parameters

oderepr = ODERepresentation(ode, **optimisations)
ccode = CCodeGenerator(oderepr)

init_state_code = (ccode.init_states_code().replace(
    "void", "__global__ void").replace(
        "{",
        "{\n  const int thread_ind = blockIdx.x*blockDim.x + threadIdx.x;"
        "\n  const int offset = thread_ind*%d;" % num_states,
예제 #19
0
def pacing_cv(ode, BCL_range, D, L, h, dt):
    """Simulate a 1D strand with operator splitting between ODE and diffusion.

    The cell model ``ode`` is loaded and JIT compiled with field state "V"
    and field parameter "stim_period". Every node of
    ``IntervalMesh(int(L / h), 0, L)`` is initialised to a steady state for
    BCL = 1000, either loaded from ``../data/steadystate_<ode>_BCL1000.npy``
    or computed with ``find_steadystate``. The simulation then alternates an
    ODE step with an implicit diffusion solve ``(M + dt*K) u = M u_prev`` and
    plots V after every step, until t = 1e6.

    NOTE(review): the previous docstring described pacing per BCL, five beats
    and a conduction-velocity measurement. None of that is implemented here,
    ``BCL_range`` is unused, and no stimulus field is applied -- confirm the
    intent.
    """
    # Create domain
    N = int(L / h)  # number of nodes - 1
    domain = IntervalMesh(N, 0, L)

    # Define functions
    V = FunctionSpace(domain, 'Lagrange', 1)
    u = TrialFunction(V)
    up = Function(V)
    v = TestFunction(V)

    # Assemble the mass matrix and the (dt-scaled) stiffness matrix once;
    # A is constant, so its LU factorization can be reused every step.
    M = assemble(u * v * dx)
    K = assemble(Constant(dt) * inner(D * grad(u), grad(v)) * dx)
    A = M + K
    pde_solver = LUSolver(A)
    pde_solver.parameters["reuse_factorization"] = True

    # Set up ODESolver
    ode = load_ode(ode)
    compiled_ode = dolfin_jit(ode,
                              field_states=["V"],
                              field_parameters=["stim_period"])

    params = DOLFINODESystemSolver.default_parameters()
    params["scheme"] = "RL1"
    BCL = 1000

    ode_solver = DOLFINODESystemSolver(domain, compiled_ode, params=params)
    solver = ode_solver._ode_system_solvers[0]

    try:
        states = np.load("../data/steadystate_%s_BCL%d.npy" % (ode, BCL))
    except (IOError, OSError, ValueError):
        # Missing or unreadable steady state: compute it instead
        states = find_steadystate(BCL, 50, dt, ode, plot_results=False)

    # Copy the steady state into every node and its V into the PDE field
    for i in range(N + 1):
        solver.states(i)[:] = states
        up.vector()[i] = solver.states(i)[0]

    # stim_period field: BCL at the first entry, 0 for the remaining N
    solver.set_field_parameters(np.array([BCL] + [0] * N, dtype=np.float64))

    t = 0.
    tstop = 1e6
    while t < tstop:
        # Step ODE solver
        ode_solver.step((t, t + dt), up)

        # Assemble RHS and solve pde
        b = M * up.vector()
        pde_solver.solve(up.vector(), b)

        # Plot solution
        plot(up, range_min=-80.0, range_max=40.0)

        t += dt
예제 #20
0
from gotran import load_ode
from goss import *
ode = jit(load_ode('myocyte.ode'))


def f(x):
    """Return ``x`` unchanged."""
    return x


class Tull:
    """Tiny example class with a single attribute ``a`` set to 2."""

    def __init__(self):
        self.a = 2


a = f

b = Tull()
# Parenthesised single argument: prints the same on Python 2 and 3
print(isinstance(b, Tull))
예제 #21
0
# Anisotropic conductivity tensor
D = as_tensor([[Dx, 0.], [0, Dy]])

dt = 0.125
tstop = 25.0
a = Constant(1.)
V_init = -85.
V_amp = 85.
t = Constant(0.)

# Domain and solution space
do_plot = False
L = 100.
N = 1024  # 128 for a quicker run
domain = RectangleMesh(-L, -L, L, L, N, N)
cellmodel = dolfin_jit(load_ode("tentusscher_panfilov_2006_M_cell.ode"),
                       field_states=["V"])
heart = CardiacModel(domain, t, D, None, cellmodel)
ps = GOSSplittingSolver.default_parameters()
ps["pde_solver"] = "monodomain"

ps["MonodomainSolver"]["linear_solver_type"] = "iterative"
ps["MonodomainSolver"]["theta"] = 1.0
ps["ode_solver"]["solver"] = "RL1"
# Split 8 threads across MPI processes. Floor division keeps the count an
# integer on Python 3 ('/' produced a float there).
ps["ode_solver"]["num_threads"] = 8 // MPI.size(domain.mpi_comm())

# If cuda
ps["ode_solver"]["use_cuda"] = True
ps["ode_solver"]["cuda_params"]["float_precision"] = "double"
ps["ode_solver"]["cuda_params"]["solver"] = "rush_larsen"
예제 #22
0
def test_subode():
    """Assemble a ten Tusscher 2004 ODE from sub-ODEs and compare to the reference.

    Parameters, states and intermediates are added by hand. Then the "Sodium"
    and "Potassium" sub-ODEs are imported; the latter uses the prefix ``pot``.
    A subset of components is also imported from
    ``tentusscher_2004_mcell_updated``.

    The test checks two things:

    - Dependency bookkeeping is updated when an imported state replaces the
      ``Na_i`` parameter.
    - Every object of the finalized ODE matches the file-loaded reference in
      type and value.
    """
    ode_from_file = gotran.load_ode(
        _here.joinpath("tentusscher_2004_mcell_updated"))

    ode = gotran.ODE("Tentusscher_2004_merged")

    # Add parameters and states
    ode.add_parameters(
        Na_i=ScalarParam(11.6),
        Na_o=ScalarParam(140),
        K_i=ScalarParam(138.3),
        K_o=ScalarParam(5.4),
        Ca_o=ScalarParam(2),
        Ca_i=ScalarParam(0.0002),
    )

    mem = ode("Membrane")
    mem.add_states(V=ScalarParam(-86.2))

    mem.add_parameters(
        Cm=ScalarParam(0.185),
        F=ScalarParam(96485.3415),
        R=ScalarParam(8314.472),
        T=ScalarParam(310),
        V_c=ScalarParam(0.016404),
        stim_amplitude=ScalarParam(0),
        stim_duration=ScalarParam(1),
        stim_period=ScalarParam(1000),
        stim_start=ScalarParam(1),
    )

    rev_pot = ode("Reversal potentials")
    rev_pot.add_parameter("P_kna", ScalarParam(0.03))

    # Add intermediates: a smooth (sigmoid-gated) stimulus pulse
    ode.i_Stim = (
        -mem.stim_amplitude * (1 - 1 /
                               (1 + exp(5.0 * ode.t - 5.0 * mem.stim_start))) /
        (1 +
         exp(5.0 * ode.t - 5.0 * mem.stim_start - 5.0 * mem.stim_duration)))

    rev_pot.E_Na = mem.R * mem.T * log(ode.Na_o / ode.Na_i) / mem.F
    rev_pot.E_K = mem.R * mem.T * log(ode.K_o / ode.K_i) / mem.F
    rev_pot.E_Ks = (mem.R * mem.T * log(
        (ode.Na_o * rev_pot.P_kna + ode.K_o) /
        (ode.K_i + ode.Na_i * rev_pot.P_kna), ) / mem.F)
    rev_pot.E_Ca = 0.5 * mem.R * mem.T * log(ode.Ca_o / ode.Ca_i) / mem.F

    # Get the E_Na expression and one dependency
    E_Na = ode.present_ode_objects["E_Na"]
    Na_i = ode.present_ode_objects["Na_i"]

    assert isinstance(Na_i, gotran.Parameter)

    # Check dependencies
    assert Na_i in ode.expression_dependencies[E_Na]
    assert E_Na in ode.object_used_in[Na_i]

    ode.import_ode(_here.joinpath("Sodium"))

    # Check dependencies after sub ode has been loaded: Na_i is now a state
    # and E_Na must depend on the new object, not the old parameter.
    E_Na = ode.present_ode_objects["E_Na"]
    Na_i_new = ode.present_ode_objects["Na_i"]
    assert isinstance(Na_i_new, gotran.State)

    assert Na_i not in ode.expression_dependencies[E_Na]
    assert Na_i_new in ode.expression_dependencies[E_Na]
    assert E_Na not in ode.object_used_in[Na_i]
    assert E_Na in ode.object_used_in[Na_i_new]

    pot = gotran.load_ode(_here.joinpath("Potassium"))
    ode.import_ode(pot, prefix="pot")

    pot_comps = [comp.name for comp in pot.components]

    # Add sub ode by extracting components
    ode.import_ode(
        ode_from_file,
        components=[
            "Calcium dynamics",
            "Calcium background current",
            "Calcium pump current",
            "L type ca current",
        ],
    )

    # Get all the Potassium currents
    i_K1 = ode("Inward rectifier potassium current").pot_i_K1
    i_Kr = ode("Rapid time dependent potassium current").pot_i_Kr
    i_Ks = ode("Slow time dependent potassium current").pot_i_Ks
    i_to = ode("Transient outward current").pot_i_to
    i_p_K = ode("Potassium pump current").pot_i_p_K

    # Get all the Sodium currents
    i_Na = ode("Fast sodium current").i_Na
    i_b_Na = ode("Sodium background current").i_b_Na
    i_NaCa = ode("Sodium calcium exchanger current").i_NaCa
    i_NaK = ode("Sodium potassium pump current").i_NaK

    # Get all the Calcium currents
    i_CaL = ode("L type ca current").i_CaL
    i_b_Ca = ode("Calcium background current").i_b_Ca
    i_p_Ca = ode("Calcium pump current").i_p_Ca

    # Membrane potential derivative
    mem.dV_dt = (-i_Ks - i_p_K - i_Na - i_K1 - i_p_Ca - i_b_Ca - i_NaK -
                 i_CaL - i_Kr - ode.i_Stim - i_NaCa - i_b_Na - i_to)

    # Finalize ODE
    ode.finalize()

    # Name prefix carried by objects imported with import_ode(prefix="pot")
    pot_prefix = "pot_"
    for name, obj in ode.present_ode_objects.items():

        # Objects in the prefixed potassium components are looked up in the
        # reference model without the prefix; strip only the leading one.
        if (ode.object_component[obj].name in pot_comps
                and name.startswith(pot_prefix)):
            ref_name = name[len(pot_prefix):]
        else:
            ref_name = name
        loaded_obj = ode_from_file.present_ode_objects[ref_name]

        assert type(obj) is type(loaded_obj)
        assert loaded_obj.param.value == pytest.approx(obj.param.value)
예제 #23
0
def spiral(ode,
           domain,
           D,
           S1,
           S2,
           plot_args=None,
           save_fig=False,
           threshold=0,
           BCL=500):
    """Initiate a spiral wave with an S1-S2 protocol and record activations.

    The cell model ``ode`` is loaded, and every node of ``domain`` is
    initialised from a steady state paced at ``BCL``. The simulation then runs
    in three phases:

    1. The S1 field parameters are applied until ``S2_time``.
    2. The S2 field parameters are applied for a further 100 ms.
    3. All stimuli are switched off until ``tstop``.

    Note that ``dt``, ``tstop``, ``S2_time``, ``field_states``,
    ``field_parameters`` and ``GOSSparams`` are read from module globals.

    Parameters
    ----------
    ode : str
        Name of the gotran ODE file. It is also used in the steadystate file
        name and in the names of saved frames.
    domain : dolfin mesh
        Computational domain.
    D : diffusion tensor
        Passed to ``CardiacModel``.
    S1, S2 : dolfin expressions
        Interpolated into a vector CG1 space and used as the ODE field
        parameters during the respective stimulus phases.
    plot_args : dict, optional
        Extra keyword arguments forwarded to ``plot``.
    save_fig : bool
        Write every 5th frame as a PNG under ``../tmp/``.
    threshold : float
        An upward crossing of this value by the monitored node counts as an
        activation.
    BCL : int
        Basic cycle length used to look up the steadystate file.

    Returns
    -------
    list of float
        Start times of the time steps during the unstimulated phase at which
        ``u.vector()[-1]`` crossed ``threshold`` upwards.
    """
    if plot_args is None:
        plot_args = {}

    def _save_frame(fig, index):
        # Write every 5th frame to ../tmp when frame saving is enabled.
        if save_fig and index % 5 == 0:
            padded_index = '%08d' % index
            fig.write_png('../tmp/randfib2_spiral_%s' % ode + padded_index)

    # Load the cell model from file
    cellmodel = load_ode(ode)
    cellmodel = dolfin_jit(cellmodel, field_states, field_parameters)

    # Create the CardiacModel for the given domain and cell model
    heart = CardiacModel(domain, Constant(0.0), D, None, cellmodel)

    # Create the solver
    solver = GOSSplittingSolver(heart, GOSSparams())
    dolfin_solver = solver.ode_solver
    ode_solver = dolfin_solver._ode_system_solvers[0]

    # Calculate parameter fields
    V = VectorFunctionSpace(domain, 'CG', 1)
    S1 = interpolate(S1, V).vector().array()
    S2 = interpolate(S2, V).vector().array()
    nostim = np.zeros(S1.shape, dtype=np.float64)

    # Read in steady state from file; fall back to pacing the 0D model when
    # the file cannot be read.
    try:
        sspath = "../data/steadystates/"
        steadystate = np.load(sspath + "steadystate_%s_BCL%d.npy" % (ode, BCL))
    except (IOError, OSError):
        print("Did not find steadystate for %s at BCL: %g, pacing 0D model" %
              (ode, BCL))
        # NOTE(review): ``ODE`` is not a parameter here and must be a module
        # global, while the missing file was for ``ode`` -- confirm which
        # model find_steadystate should pace.
        steadystate = find_steadystate(ODE, BCL, 0.01)

    # Load steadystate into 2D model
    (u, um) = solver.solution_fields()
    for node in range(domain.coordinates().shape[0]):
        ode_solver.states(node)[:] = steadystate
        u.vector()[node] = steadystate[0]

    cnt = 0
    # Stimulate S1 pulse
    plot(u, interactive=True, **plot_args)
    ode_solver.set_field_parameters(S1)
    for timestep, (u, vm) in solver.solve((0, S2_time), dt):
        fig = plot(u, interactive=False, **plot_args)
        _save_frame(fig, cnt)
        cnt += 1

    ode_solver.set_field_parameters(S2)

    # Stimulate S2 pulse
    for timestep, (u, vm) in solver.solve((S2_time, S2_time + 100), dt):
        fig = plot(u, interactive=False, **plot_args)
        _save_frame(fig, cnt)
        cnt += 1

    # Turn off all stimulus and iterate until simulation stop. We monitor the
    # last dof of u (the top right corner, per the original comment) to find
    # the spiral period.
    ode_solver.set_field_parameters(nostim)
    trig_time = []
    xp = 0
    for timestep, (u, vm) in solver.solve((S2_time + 100, tstop), dt):
        fig = plot(u, interactive=False, **plot_args)

        x = u.vector()[-1]
        # Upward crossing: above threshold now, below it at the previous step
        if x > threshold and (x - threshold) * (xp - threshold) < 0:
            trig_time.append(timestep[0])
            print("Trigger: %g" % timestep[0])
        xp = x

        _save_frame(fig, cnt)
        cnt += 1

    return trig_time
예제 #24
0
class TestODESolvers(object):
    """Convergence-order and long-run regression tests for the goss ODE solvers.

    Each test is parametrized over ``goss_solvers``. The solver class is
    looked up by name among this module's globals.
    """

    # Expected order of convergence per solver, checked in
    # test_convergence_order.
    orders = dict(RK4=4,
                  RK2=2,
                  RL1=1,
                  RL2=2,
                  GRL1=1,
                  GRL2=2,
                  ExplicitEuler=1,
                  ThetaSolver=2,
                  BasicImplicitEuler=1,
                  RKF32=3,
                  ESDIRK23a=3,
                  ESDIRK4O32=4,
                  ImplicitEuler=1)
    oscilator = jit(load_ode("oscilator"))
    # Solvers skipped by test_convergence_order / test_long_run respectively
    exclude_osc = ["ESDIRK4O32", "ESDIRK23a"]
    exclude_tent = ["RKF32", "ESDIRK4O32", "ESDIRK23a", "BasicImplicitEuler"]
    dir_path = os.path.dirname(__file__)
    tentusscher = jit(load_ode("tentusscher_2004_mcell"))

    @parametrize(("solver_str"), goss_solvers)
    def test_convergence_order(self, solver_str):
        """Check the observed convergence order on the oscilator model.

        The solution is compared at t = 10 against (cos(t), sin(t)), starting
        from (1, 0), for four halving time steps.
        """
        import pytest

        if solver_str in self.exclude_osc:
            pytest.skip("%s is excluded from the oscilator test" % solver_str)

        solver = globals()[solver_str](self.oscilator)

        tstop = 10.
        exact = np.array([np.cos(tstop), np.sin(tstop)])
        errors = []
        for dt in [0.05, 0.025, 0.0125, 0.00625]:
            u = np.array([1.0, 0.])
            t = 0.0
            nsteps = int(tstop / dt)
            for step in range(nsteps):
                solver.forward(u, t, dt)
                t += dt

            # Relative error in the 2-norm
            errors.append(np.sqrt(np.sum(((exact - u) / exact)**2)))

        assert min(convergence_order(errors)) >= self.orders[solver_str] - 0.1

    @parametrize(("solver_str"), goss_solvers)
    def test_long_run(self, solver_str):
        """Integrate the stimulated tentusscher model for 10 time units and
        compare the final value of V against a stored reference value."""
        import pytest

        if solver_str in self.exclude_tent:
            pytest.skip("%s is excluded from the tentusscher test" % solver_str)

        dt = 0.0002
        tstop = 10
        ind_V = self.tentusscher.get_state_names().index("V")

        solver = globals()[solver_str](self.tentusscher)

        self.tentusscher.set_parameter("stim_amplitude", 52.)
        self.tentusscher.set_parameter("stim_start", 0.)

        u = self.tentusscher.init_state_values()

        t = 0.0
        nsteps = int(tstop / dt)
        for step in range(nsteps):
            solver.forward(u, t, dt)
            t += dt

        # Test against run with scipy integrate
        assert abs(u[ind_V] - 12.948) < 1e-3
예제 #25
0
def _test_codegeneration(
    module,
    body_repr=default_params["code"]["body"]["representation"],
    body_optimize="none",
    param_repr="named",
    state_repr="named",
    use_cse=False,
    use_enum=False,
    float_precision="double",
):
    """Compare generated Python rhs/Jacobian code against a reference module.

    The reference values come from ``module``'s ``rhs`` and
    ``compute_jacobian``, evaluated at its initial state and parameter values.
    The test then regenerates code for ``tentusscher_2004_mcell_updated.ode``
    with the given code-generation options, executes it, and asserts that the
    L2 norm of the difference is below a precision-dependent tolerance.
    """
    parameters_name = "parameters"
    states_name = "states"
    body_array_name = "body"
    body_in_arg = False

    def _l2_diff(reference, values):
        # Euclidean (Frobenius for 2-D) norm of the difference; squaring
        # before summing prevents errors of opposite sign from cancelling.
        diff = np.asarray(reference) - np.asarray(values)
        return np.sqrt(np.sum(diff ** 2))

    states_values = module.init_state_values()
    parameter_values = module.init_parameter_values()
    rhs_ref_values = module.rhs(states_values, 0.0, parameter_values)
    jac_ref_values = module.compute_jacobian(states_values, 0.0, parameter_values)

    gen_params = parameters["generation"].copy()

    # Update code_params
    code_params = gen_params["code"]
    code_params["body"]["optimize_exprs"] = body_optimize
    code_params["body"]["representation"] = body_repr
    code_params["body"]["in_signature"] = body_in_arg
    code_params["body"]["array_name"] = body_array_name
    code_params["parameters"]["representation"] = param_repr
    code_params["parameters"]["array_name"] = parameters_name
    code_params["states"]["representation"] = state_repr
    code_params["states"]["array_name"] = states_name
    code_params["body"]["use_cse"] = use_cse
    code_params["body"]["use_enum"] = use_enum
    code_params["float_precision"] = float_precision
    code_params["default_arguments"] = "stp"

    # Reload ODE for each test
    ode = gotran.load_ode(_here.joinpath("tentusscher_2004_mcell_updated.ode"))
    codegen = gotran.PythonCodeGenerator(gen_params)
    rhs_comp = gotran.rhs_expressions(ode, params=code_params)
    rhs_code = codegen.function_code(rhs_comp)

    rhs_namespace = {}
    exec(rhs_code, rhs_namespace)

    args = [states_values, 0.0]
    if param_repr != "numerals":
        args.append(parameter_values)

    if body_in_arg:
        body = np.zeros(rhs_comp.shapes[body_array_name])
        args.append(body)

    # Call the generated rhs function
    rhs_values = rhs_namespace["rhs"](*args)

    rhs_norm = _l2_diff(rhs_ref_values, rhs_values)

    eps = 1e-8 if float_precision == "double" else 1e-6
    assert rhs_norm < eps

    # Skip the Jacobian check only when all three hold: no numerals_symbols
    # optimisation, a body representation other than reused_array, and named
    # parameters.
    # NOTE(review): an earlier comment said the Jacobian is evaluated only for
    # numerals_symbols + reused_array, which does not match this condition --
    # confirm the intent.
    if (
        body_optimize != "numerals_symbols"
        and body_repr != "reused_array"
        and param_repr == "named"
    ):
        return

    jac_comp = gotran.jacobian_expressions(ode, params=code_params)
    jac_code = codegen.function_code(jac_comp)

    jac_namespace = {}
    exec(jac_code, jac_namespace)

    args = [states_values, 0.0]
    if param_repr != "numerals":
        args.append(parameter_values)

    if body_in_arg:
        body = np.zeros(jac_comp.shapes[body_array_name])
        args.append(body)

    jac_values = jac_namespace["compute_jacobian"](*args)
    jac_norm = _l2_diff(jac_ref_values, jac_values)

    eps = 1e-8 if float_precision == "double" else 1e-3
    assert jac_norm < eps
예제 #26
0
    N = (x_nodes + 1) * (y_nodes + 1)
    T = 100
    dt = 0.5
    t = 0
    time_steps = int((T - t) / dt)
    time_solution_method = 'BE'  ### crank nico

    save = False  #save solutions as binary
    savemovie = False  #create movie from results. Takes time!
    plot_realtime = True

    # small hack
    mesh = UnitSquareMesh(x_nodes, y_nodes)
    space = FunctionSpace(mesh, 'Lagrange', 1)
    ### Setting up Goss/Gotran part
    ode = jit(load_ode("myocyte.ode"))
    vertex_to_dof_map = space.dofmap().vertex_to_dof_map(mesh)
    N_thread = mesh.coordinates().shape[0]
    print vertex_to_dof_map.shape, mesh.coordinates().shape

    ist = np.zeros(N_thread, dtype=np.float_)
    ist = stimulation_domain(mesh.coordinates(), amp=-10)

    P0 = make_parameter_field(mesh.coordinates(), ode)
    P1 = make_parameter_field(mesh.coordinates(), ode, ist=ist)

    ind_stim = P1[:, 1] != 0
    print "P0", P0[P0[:, 1] != 0., 1]
    print "P1", P1[ind_stim, 1]

    solver = ThetaSolver()
예제 #27
0
def test_creation():
    """Build a small ODE by hand and exercise gotran's ODE construction API.

    The test covers:

    - states, parameters, intermediates and components;
    - derivatives, algebraic and solve-state expressions;
    - Markov-model rates, including the error cases;
    - the mass matrix;
    - a save/load round trip, which must reproduce every object.
    """

    # Adding a phoney ODE
    ode = gotran.ODE("test")

    # Add states and parameters
    j = ode.add_state("j", 1.0)
    i = ode.add_state("i", 2.0)
    k = ode.add_state("k", 3.0)

    ii = ode.add_parameter("ii", 0.0)
    jj = ode.add_parameter("jj", 0.0)
    kk = ode.add_parameter("kk", 0.0)

    # Try overwriting state
    with pytest.raises(gotran.GotranException):
        ode.add_parameter("j", 1.0)

    # Try overwriting parameter
    with pytest.raises(gotran.GotranException):
        ode.add_state("ii", 1.0)

    assert ode.num_states == 3
    assert ode.num_parameters == 3
    assert ode.present_component == ode

    # Add an Expression
    ode.alpha = i * j

    # Add derivatives for all states in the main component
    ode.add_comment("Some nice derivatives and an algebraic expression")
    ode.di_dt = ode.alpha + ii
    ode.dj_dt = -ode.alpha - jj
    ode.alg_k_0 = kk * k * ode.alpha

    assert ode.num_intermediates == 1

    # Add a component with 4 states
    ode("jada").add_states(m=2.0, n=3.0, l=1.0, o=4.0)
    ode("jada").add_parameters(ll=1.0, mm=2.0)

    # Define a state derivative
    ode("jada").dm_dt = ode("jada").ll - (ode("jada").m - ode.i)

    jada = ode("jada")
    assert ode.present_component == jada

    # Test num_foo
    assert jada.num_states == 4
    assert jada.num_parameters == 2
    assert ode.num_states == 7
    assert ode.num_parameters == 5
    assert ode.num_components == 2
    assert jada.num_components == 1

    # Add expressions to the component
    jada.tmp = jada.ll * jada.m**2 + 3 / i - ii * jj
    jada.tmp2 = ode.j * exp(jada.tmp)

    # Reduce state n
    jada.add_solve_state(jada.n, 1 - jada.l - jada.m - jada.n)

    assert ode.num_intermediates == 4

    # Try overwriting parameter with expression
    with pytest.raises(gotran.GotranException):
        jada.ll = jada.tmp * jada.tmp2

    # Create a derivative expression
    ode.add_comment("More funky objects")
    jada.tmp3 = jada.tmp2.diff(ode.t) + jada.n + jada.o
    jada.add_derivative(jada.l, ode.t, jada.tmp3)
    jada.add_algebraic(jada.o, jada.o**2 - exp(jada.o) + 2 / jada.o)

    assert ode.num_intermediates == 9
    assert ode.num_state_expressions == 6
    assert ode.is_complete
    assert ode.num_full_states == 6

    # Try adding expressions to ode component
    with pytest.raises(gotran.GotranException):
        ode.p = 1.0

    # Check used in and dependencies for one intermediate
    tmp3 = ode.present_ode_objects["tmp3"]
    assert ode.object_used_in[tmp3] == {ode.present_ode_objects["dl_dt"]}

    for sym in symbols_from_expr(tmp3.expr, include_derivatives=True):
        assert (ode.present_ode_objects[sympycode(sym)]
                in ode.expression_dependencies[tmp3])

    # Add another component to test rates
    bada = ode("bada")
    bada.add_parameters(nn=5.0, oo=3.0, qq=1.0, pp=2.0)

    nada = bada.add_component("nada")
    nada.add_states(("r", 3.0), ("s", 4.0), ("q", 1.0), ("p", 2.0))
    assert bada.num_parameters == 4
    assert bada.num_states == 4
    nada.p = 1 - nada.r - nada.s - nada.q

    assert "".join(p.name for p in ode.parameters) == "iijjkkllmmnnooppqq"
    assert "".join(s.name for s in ode.states) == "jiklmnorsqp"
    assert not ode.is_complete

    # Add rates to component making it a Markov model component
    nada.rates[nada.r, nada.s] = 3 * exp(-i)

    # Try add a state derivative to Markov model
    with pytest.raises(gotran.GotranException):
        nada.ds_dt = 3.0

    nada.rates[nada.s, nada.r] = 2.0
    nada.rates[nada.s, nada.q] = 2.0
    nada.rates[nada.q, nada.s] = 2 * exp(-i)
    nada.rates[nada.q, nada.p] = 3.0
    nada.rates[nada.p, nada.q] = 4.0

    assert ode.present_component == nada

    markov = bada.add_component("markov_2")
    markov.add_states(("tt", 3.0), ("u", 4.0), ("v", 1.0))

    # Invalid rate assignments: states from another component, state-dependent
    # rate, scalar for a state list, wrongly shaped matrix, self transition.
    with pytest.raises(gotran.GotranException):
        markov.rates[nada.s, nada.r] = 2.0

    with pytest.raises(gotran.GotranException):
        markov.rates[markov.tt, markov.u] = 2 * exp(markov.u)

    with pytest.raises(gotran.GotranException):
        markov.rates[[markov.tt, markov.u, markov.v]] = 5.0

    with pytest.raises(gotran.GotranException):
        markov.rates[[markov.tt, markov.u,
                      markov.v]] = Matrix([[1, 2 * i, 0.0], [0.0, 2.0, 4.0]], )
    with pytest.raises(gotran.GotranException):
        markov.rates[markov.tt, markov.tt] = 5.0

    markov.rates[[markov.tt, markov.u, markov.v]] = Matrix(
        [[0.0, 2 * i, 2.0], [4.0, 0.0, 2.0], [5.0, 2.0, 0.0]], )

    ode.finalize()
    assert ode.is_complete
    assert ode.is_dae

    # Test Mass matrix: the two algebraic states (indices 2 and 5) have zero
    # rows.
    vector = ode.mass_matrix * Matrix([1] * ode.num_full_states)
    assert (0, 0) == (vector[2], vector[5])
    assert sum(ode.mass_matrix) == ode.num_full_states - 2
    assert sum(vector) == ode.num_full_states - 2

    assert "".join(s.name for s in ode.full_states) == "ijkmlorsqttuv"
    assert ode.present_component == ode

    # Test saving
    ode.save("test_ode")

    # Test loading; always remove the saved file, even if loading fails
    try:
        ode_loaded = gotran.load_ode("test_ode")
    finally:
        os.unlink("test_ode.ode")

    # TODO: also compare ode.signature() with ode_loaded.signature()

    # Check that all objects are the same and evaluate to the same value
    for name, obj in ode.present_ode_objects.items():
        loaded_obj = ode_loaded.present_ode_objects[name]
        assert type(obj) is type(loaded_obj)
        assert loaded_obj.param.value == pytest.approx(obj.param.value)
예제 #28
0
	N = (x_nodes+1)*(y_nodes+1)
	T = 100
	dt = 0.5
	t = 0
	time_steps = int((T-t)/dt)
	time_solution_method = 'BE' ### crank nico

	save = False #save solutions as binary
	savemovie = False #create movie from results. Takes time! 
	plot_realtime = True

	# small hack
	mesh = UnitSquareMesh(x_nodes, y_nodes) 
	space = FunctionSpace(mesh, 'Lagrange', 1)
	### Setting up Goss/Gotran part
	ode = jit(load_ode("myocyte.ode"))
	vertex_to_dof_map =  space.dofmap().vertex_to_dof_map(mesh)
	N_thread = mesh.coordinates().shape[0]
	print vertex_to_dof_map.shape, mesh.coordinates().shape
	   
	ist = np.zeros(N_thread, dtype=np.float_) 
	ist = stimulation_domain(mesh.coordinates(), amp= -10)

	P0 = make_parameter_field(mesh.coordinates(), ode)
	P1 = make_parameter_field(mesh.coordinates(), ode, ist=ist) 

	ind_stim = P1[:,1]!=0
	print "P0", P0[P0[:,1]!=0.,1]
	print "P1", P1[ind_stim,1]

	solver = GRL2()
예제 #29
0
from gotran import load_ode
from goss import *
# Load the gotran model in myocyte.ode and pass it to goss' jit
# (presumably compiling it; load_ode/jit come from the star imports above).
ode =  jit(load_ode('myocyte.ode'))

def f(x):
    """Return the argument unchanged (identity function)."""
    result = x
    return result

class Tull:
    """Minimal example class whose instances carry the attribute ``a == 2``."""

    def __init__(self):
        initial_a = 2
        self.a = initial_a



# Module-level demo: alias the identity function and check that an instance
# of Tull is recognised by isinstance.
a = f

b = Tull()
# A parenthesised single-argument print behaves identically under Python 2
# and is required syntax under Python 3.
print(isinstance(b, Tull))
예제 #30
0
def pacing_cv(ode, BCL_range, D, L, h, dt):
    """
    Pace a 0D cell model for 50 cycles at given BCL, then
    simulate a wave in a 1D strand. 5 beats are initiated
    at the left hand side of the strand, after the fifth
    beat, the conduction velocity is measured.

    NOTE(review): the body does not yet implement this protocol.
    ``BCL_range`` is unused and the BCL is fixed at 1000. The time loop runs
    to t = 1e6 and plots every step, and no conduction velocity is computed
    or returned.

    Parameters
    ----------
    ode : str
        Name of the gotran ODE file; the name is rebound to the loaded ODE.
    BCL_range : unused
    D : diffusion coefficient used in the stiffness form.
    L : float
        Strand length.
    h : float
        Target spacing; the mesh has ``int(L/h)`` cells.
    dt : float
        Time step for both the ODE and the PDE step.
    """

    # Create domain
    N = int(L / h)  # number of nodes - 1
    domain = IntervalMesh(N, 0, L)

    # Define functions
    V = FunctionSpace(domain, 'Lagrange', 1)
    u = TrialFunction(V)
    up = Function(V)
    v = TestFunction(V)

    # Assemble the mass matrix and the stiffness matrix (K already includes
    # dt). A is constant, so its factorization is reused every step.
    M = assemble(u * v * dx)
    K = assemble(Constant(dt) * inner(D * grad(u), grad(v)) * dx)
    A = M + K
    pde_solver = LUSolver(A)
    pde_solver.parameters["reuse_factorization"] = True

    # Set up ODESolver; "stim_period" is a per-node field parameter
    ode = load_ode(ode)
    compiled_ode = dolfin_jit(ode, field_states=["V"],
                              field_parameters=["stim_period"])

    params = DOLFINODESystemSolver.default_parameters()
    params["scheme"] = "RL1"
    BCL = 1000

    ode_solver = DOLFINODESystemSolver(domain, compiled_ode, params=params)
    solver = ode_solver._ode_system_solvers[0]

    # NOTE(review): ``ode`` is the loaded ODE object here, so the file name is
    # built from str(ode), not the name passed in -- confirm this is intended.
    try:
        states = np.load("../data/steadystate_%s_BCL%d.npy" % (ode, BCL))
    except (IOError, OSError):
        states = find_steadystate(BCL, 50, dt, ode, plot_results=False)

    # Initialise every node's ODE state and the PDE field from the steady state
    for i in range(N + 1):
        solver.states(i)[:] = states
        up.vector()[i] = solver.states(i)[0]

    # stim_period field: BCL in the first entry, 0 elsewhere
    solver.set_field_parameters(np.array([BCL] + [0] * N, dtype=np.float64))

    t = 0.
    tstop = 1e6
    while t < tstop:
        # Step ODE solver
        ode_solver.step((t, t + dt), up)

        # Assemble RHS and solve pde
        b = M * up.vector()
        pde_solver.solve(up.vector(), b)

        # Plot solution
        plot(up, range_min=-80.0, range_max=40.0)

        t += dt
예제 #31
0
파일: spiral.py 프로젝트: jvbrink/atriapace
def spiral(ode, domain, D, S1, S2, plot_args=None, save_fig=False, threshold=0, BCL=500):
    """Initiate a spiral wave with an S1-S2 protocol and record activations.

    The cell model ``ode`` is loaded, and every node of ``domain`` is
    initialised from a steady state paced at ``BCL``. The simulation then runs
    in three phases:

    1. The S1 field parameters are applied until ``S2_time``.
    2. The S2 field parameters are applied for a further 100 ms.
    3. All stimuli are switched off until ``tstop``.

    Note that ``dt``, ``tstop``, ``S2_time``, ``field_states``,
    ``field_parameters`` and ``GOSSparams`` are read from module globals.

    Parameters
    ----------
    ode : str
        Name of the gotran ODE file. It is also used in the steadystate file
        name and in the names of saved frames.
    domain : dolfin mesh
        Computational domain.
    D : diffusion tensor
        Passed to ``CardiacModel``.
    S1, S2 : dolfin expressions
        Interpolated into a vector CG1 space and used as the ODE field
        parameters during the respective stimulus phases.
    plot_args : dict, optional
        Extra keyword arguments forwarded to ``plot``.
    save_fig : bool
        Write every 5th frame as a PNG under ``../tmp/``.
    threshold : float
        An upward crossing of this value by the monitored node counts as an
        activation.
    BCL : int
        Basic cycle length used to look up the steadystate file.

    Returns
    -------
    list of float
        Start times of the time steps during the unstimulated phase at which
        ``u.vector()[-1]`` crossed ``threshold`` upwards.
    """
    if plot_args is None:
        plot_args = {}

    def _save_frame(fig, index):
        # Write every 5th frame to ../tmp when frame saving is enabled.
        if save_fig and index % 5 == 0:
            padded_index = '%08d' % index
            fig.write_png('../tmp/spiral_%s' % ode + padded_index)

    # Load the cell model from file
    cellmodel = load_ode(ode)
    cellmodel = dolfin_jit(cellmodel, field_states, field_parameters)

    # Create the CardiacModel for the given domain and cell model
    heart = CardiacModel(domain, Constant(0.0), D, None, cellmodel)

    # Create the solver
    solver = GOSSplittingSolver(heart, GOSSparams())
    dolfin_solver = solver.ode_solver
    ode_solver = dolfin_solver._ode_system_solvers[0]

    # Calculate parameter fields
    V = VectorFunctionSpace(domain, 'CG', 1)
    S1 = interpolate(S1, V).vector().array()
    S2 = interpolate(S2, V).vector().array()
    nostim = np.zeros(S1.shape, dtype=np.float64)

    # Read in steady state from file; fall back to pacing the 0D model when
    # the file cannot be read.
    try:
        sspath = "../data/steadystates/"
        steadystate = np.load(sspath + "steadystate_%s_BCL%d.npy" % (ode, BCL))
    except (IOError, OSError):
        print("Did not find steadystate for %s at BCL: %g, pacing 0D model" %
              (ode, BCL))
        # NOTE(review): ``ODE`` is not a parameter here and must be a module
        # global, while the missing file was for ``ode`` -- confirm which
        # model find_steadystate should pace.
        steadystate = find_steadystate(ODE, BCL, 0.01)

    # Load steadystate into 2D model
    (u, um) = solver.solution_fields()
    for node in range(domain.coordinates().shape[0]):
        ode_solver.states(node)[:] = steadystate
        u.vector()[node] = steadystate[0]

    cnt = 0
    # Stimulate S1 pulse
    plot(u, interactive=True, **plot_args)
    ode_solver.set_field_parameters(S1)
    for timestep, (u, vm) in solver.solve((0, S2_time), dt):
        fig = plot(u, interactive=False, **plot_args)
        _save_frame(fig, cnt)
        cnt += 1

    ode_solver.set_field_parameters(S2)

    # Stimulate S2 pulse
    for timestep, (u, vm) in solver.solve((S2_time, S2_time + 100), dt):
        fig = plot(u, interactive=False, **plot_args)
        _save_frame(fig, cnt)
        cnt += 1

    # Turn off all stimulus and iterate until simulation stop. We monitor the
    # last dof of u (the top right corner, per the original comment) to find
    # the spiral period.
    ode_solver.set_field_parameters(nostim)
    trig_time = []
    xp = 0
    for timestep, (u, vm) in solver.solve((S2_time + 100, tstop), dt):
        fig = plot(u, interactive=False, **plot_args)

        x = u.vector()[-1]
        # Upward crossing: above threshold now, below it at the previous step
        if x > threshold and (x - threshold) * (xp - threshold) < 0:
            trig_time.append(timestep[0])
            print("Trigger: %g" % timestep[0])
        xp = x

        _save_frame(fig, cnt)
        cnt += 1

    return trig_time