def test_Horsager2009Temporal():
    """Test parameter handling and reference outputs of Horsager2009Temporal.

    Checks that:
    - ``dt`` can be set directly and through ``build``;
    - no new parameters can be added (``FreezeError``);
    - no stimulus yields ``None``, and an all-zero stimulus yields an all-zero
      percept;
    - duplicate ``t_percept`` entries are rejected with ``ValueError``;
    - the peak brightness matches the Fig. 3 (single pulse) and Fig. 4
      (pulse train) reference values.
    """
    model = Horsager2009Temporal()
    # User can set their own params:
    model.dt = 0.1
    npt.assert_equal(model.dt, 0.1)
    model.build(dt=1e-4)
    npt.assert_equal(model.dt, 1e-4)
    # User cannot add more model parameters:
    with pytest.raises(FreezeError):
        model.rho = 100

    # Nothing in, None out:
    implant = ProsthesisSystem(PointSource(0, 0, 0))
    npt.assert_equal(model.predict_percept(implant.stim), None)

    # Zero in = zero out:
    implant.stim = np.zeros((1, 6))
    percept = model.predict_percept(implant.stim, t_percept=[0, 1, 2])
    npt.assert_equal(isinstance(percept, Percept), True)
    npt.assert_equal(percept.shape, (1, 1, 3))
    npt.assert_almost_equal(percept.data, 0)

    # Can't request the same time more than once (this would break the Cython
    # loop, because `idx_frame` is incremented after a write; also doesn't
    # make much sense).
    # The stimulus is assigned *outside* `pytest.raises` so that only
    # `predict_percept` can satisfy the expected ValueError. A ValueError from
    # the stim setter would otherwise make this check pass spuriously.
    implant.stim = np.ones((1, 100))
    with pytest.raises(ValueError):
        model.predict_percept(implant.stim, t_percept=[0.2, 0.2])

    # Single-pulse brightness from Fig.3. Each (amplitude, phase duration)
    # pair is expected to reach the same peak brightness:
    model = Horsager2009Temporal().build()
    for amp, pdur in zip([188.077, 89.74, 10.55], [0.075, 0.15, 4.0]):
        stim = BiphasicPulse(amp,
                             pdur,
                             interphase_dur=pdur,
                             stim_dur=200,
                             cathodic_first=True)
        # Sample every `dt` from 0 through the last stimulus time point. The
        # half-step padding keeps `np.arange` from dropping the endpoint due
        # to float rounding.
        t_percept = np.arange(0, stim.time[-1] + model.dt / 2, model.dt)
        percept = model.predict_percept(stim, t_percept=t_percept)
        npt.assert_almost_equal(percept.data.max(), 110.3, decimal=2)

    # Fixed-duration brightness from Fig.4. Each (amplitude, frequency) pair
    # is expected to reach the same peak brightness:
    model = Horsager2009Temporal().build()
    for amp, freq in zip([136.02, 120.35, 57.71], [5, 15, 225]):
        stim = BiphasicPulseTrain(freq,
                                  amp,
                                  0.075,
                                  interphase_dur=0.075,
                                  stim_dur=200,
                                  cathodic_first=True)
        t_percept = np.arange(0, stim.time[-1] + model.dt / 2, model.dt)
        percept = model.predict_percept(stim, t_percept=t_percept)
        npt.assert_almost_equal(percept.data.max(), 36.3, decimal=2)
# Example #2
def test_Horsager2009Model():
    """Test Horsager2009Model flags, ``dt`` handling and temporal equivalence.

    Checks that:
    - the model reports no spatial component (``has_space`` is False) but a
      temporal one (``has_time`` is True);
    - ``dt`` set on the temporal sub-model or via ``build`` is reflected on
      both the model and its temporal sub-model;
    - no new parameters can be added;
    - its percepts match those of a standalone ``Horsager2009Temporal`` for
      the same pulse trains.
    """
    model = Horsager2009Model()
    npt.assert_equal(hasattr(model, 'has_space'), True)
    npt.assert_equal(model.has_space, False)
    npt.assert_equal(hasattr(model, 'has_time'), True)
    npt.assert_equal(model.has_time, True)

    # User can set `dt` on the temporal sub-model or via `build`. The value
    # shows up on both `model` and `model.temporal`:
    model.temporal.dt = 1e-5
    npt.assert_almost_equal(model.dt, 1e-5)
    npt.assert_almost_equal(model.temporal.dt, 1e-5)
    model.build(dt=3e-4)
    npt.assert_almost_equal(model.dt, 3e-4)
    npt.assert_almost_equal(model.temporal.dt, 3e-4)

    # User cannot add more model parameters:
    with pytest.raises(FreezeError):
        model.rho = 100

    # Model and TemporalModel give the same result.
    # Model construction does not depend on the stimulus, so build both
    # models once instead of on every loop iteration:
    model1 = Horsager2009Model().build()
    model2 = Horsager2009Temporal().build()
    for amp, freq in zip([136.02, 120.35, 57.71], [5, 15, 225]):
        stim = BiphasicPulseTrain(freq, amp, 0.075, interphase_dur=0.075,
                                  stim_dur=200, cathodic_first=True)
        # The full model takes an implant; the temporal model takes the
        # stimulus directly.
        implant = ProsthesisSystem(PointSource(0, 0, 0), stim=stim)
        npt.assert_almost_equal(model1.predict_percept(implant).data,
                                model2.predict_percept(stim).data)
# NOTE(review): this appears to be a fragment of a separate example script.
# `phase_dur`, `stim_dur`, `BiphasicPulse` and `np` are presumably
# defined/imported earlier in that script (not visible here) — confirm.
# Single biphasic pulse:
# - amplitude 180, phase duration `phase_dur`;
# - interphase gap equal to the phase duration;
# - stimulus duration `stim_dur`;
# - cathodic phase first.
pulse = BiphasicPulse(180,
                      phase_dur,
                      interphase_dur=phase_dur,
                      stim_dur=stim_dur,
                      cathodic_first=True)
# Plot the pulse on a dense grid of 10,000 evenly spaced points from 0 to 10,
# in the stimulus' time units (presumably ms — confirm).
pulse.plot(time=np.linspace(0, 10, num=10000))

###############################################################################
# Simulating the model response
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# To visualize the model's response to this stimulus, we build a
# Horsager2009Temporal model, predict the percept elicited by the pulse, and
# plot the pulse and the percept on a common time axis:

from pulse2percept.models import Horsager2009Temporal

# Build the model with its default parameters.
model = Horsager2009Temporal()
model.build()

# No `t_percept` is passed here, so the time points at which the percept is
# sampled are left to the model's default behavior.
percept = model.predict_percept(pulse)

# Peak brightness over the entire percept.
max_bright = percept.data.max()

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(12, 5))
# Draw the pulse as a schematic: scale it so its largest value maps to 10,
# then shift it down by 20. Presumably this keeps it visually separate from
# the percept trace.
ax.plot(pulse.time,
        -20 + 10 * pulse.data[0, :] / pulse.data.max(),
        linewidth=3,
        label='pulse')
# percept.data is indexed [0, 0, :], i.e. the time course at the single
# spatial location, plotted against percept.time.
ax.plot(percept.time, percept.data[0, 0, :], linewidth=3, label='percept')
# Dashed black horizontal line marking the peak brightness, spanning 0 to
# stim_dur on the time axis.
ax.plot([0, stim_dur], [max_bright, max_bright], 'k--', label='max brightness')