def test_predict_batched(engine):
    """Batched and serial predictions must agree for the jax engine.

    ``stims`` mixes plain electrode->pulse-train dicts with a ready-made
    ``Stimulus`` to show that both forms are accepted. Every engine other
    than 'jax' is expected to raise ``ImportError`` from
    ``predict_percept_batched``.
    """
    if not has_jax:
        pytest.skip("Jax not installed")

    two_electrodes = {
        'A5': BiphasicPulseTrain(25, 4, 0.45),
        'C7': BiphasicPulseTrain(50, 2.5, 0.75),
    }
    one_electrode = {'B4': BiphasicPulseTrain(3, 1, 0.32)}
    as_stimulus = Stimulus({'F3': BiphasicPulseTrain(12, 3, 1.2)})
    stims = [two_electrodes, one_electrode, as_stimulus]

    implant = ArgusII()
    model = BiphasicAxonMapModel(engine=engine, xystep=2)
    model.build()

    if engine == 'jax':
        batched = model.predict_percept_batched(implant, stims)
    else:
        # Batched prediction is only available with the jax engine; any
        # other engine raises ImportError.
        with pytest.raises(ImportError):
            model.predict_percept_batched(implant, stims)
        return

    def _predict_single(stim):
        # Serial reference: assign one stimulus at a time to the implant.
        implant.stim = stim
        return model.predict_percept(implant)

    serial = [_predict_single(stim) for stim in stims]

    npt.assert_equal(len(serial), len(batched))
    for got, expected in zip(batched, serial):
        npt.assert_almost_equal(got.data, expected.data)
def test_predict_spatial_jax():
    """The jax engine reproduces the cython engine's spatial predictions.

    After changing ``axlambda`` and ``rho`` and rebuilding, both engines
    must agree again, i.e. jax must not reuse values cached before build.
    """
    if not has_jax:
        pytest.skip("Jax not installed")
    jax_model = BiphasicAxonMapModel(engine='jax', xystep=2)
    cython_model = BiphasicAxonMapModel(engine='cython', xystep=2)
    models = (jax_model, cython_model)
    for mdl in models:
        mdl.build()

    implant = ArgusII()
    implant.stim = {
        'A5': BiphasicPulseTrain(25, 4, 0.45),
        'C7': BiphasicPulseTrain(50, 2.5, 0.75)
    }

    def _compare_engines():
        # Predict with jax first, then cython, and compare to 4 decimals.
        p_jax, p_cython = (mdl.predict_percept(implant) for mdl in models)
        npt.assert_almost_equal(p_jax.data, p_cython.data, decimal=4)

    _compare_engines()

    # Changing parameters and rebuilding must invalidate any jax cache:
    for mdl in models:
        mdl.axlambda = 800
        mdl.rho = 50
    for mdl in models:
        mdl.build()
    _compare_engines()
def test_BiphasicTripletTrain(amp, interphase_dur, interpulse_dur, delay_dur,
                              cathodic_first):
    """Sample a BiphasicTripletTrain at known offsets within every window.

    Each window is ``1000 / freq`` ms long. Inside each window the checks
    expect silence at the window start, the first phase centred at
    ``delay_dur + phase_dur / 2``, silence in the interphase gap, the
    opposite-polarity phase after the gap, and silence in the gap between
    the first and the second biphasic pulse.
    """
    freq = 23.456
    stim_dur = 657.456
    phase_dur = 2
    window_dur = 1000.0 / freq
    n_pulses = int(freq * stim_dur / 1000.0)
    # Offsets (ms) relative to the start of a window:
    mid_first_pulse = delay_dur + phase_dur / 2.0
    mid_interphase = delay_dur + phase_dur + interphase_dur / 2.0
    mid_second_pulse = delay_dur + interphase_dur + 1.5 * phase_dur
    mid_interpulse = (delay_dur + 2.0 * phase_dur + interphase_dur +
                      interpulse_dur / 2.0)
    first_amp = -np.abs(amp) if cathodic_first else np.abs(amp)
    second_amp = -first_amp

    # Basic usage:
    pt = BiphasicTripletTrain(freq,
                              amp,
                              phase_dur,
                              interphase_dur=interphase_dur,
                              interpulse_dur=interpulse_dur,
                              delay_dur=delay_dur,
                              stim_dur=stim_dur,
                              cathodic_first=cathodic_first)
    for i in range(n_pulses):
        t_win = i * window_dur
        npt.assert_almost_equal(pt[0, np.floor(t_win)], 0)
        npt.assert_almost_equal(pt[0, t_win + mid_first_pulse], first_amp)
        if interphase_dur > 0:
            npt.assert_almost_equal(pt[0, t_win + mid_interphase], 0)
        npt.assert_almost_equal(pt[0, t_win + mid_second_pulse], second_amp)
        if interpulse_dur > 0:
            # Offset by t_win so that the gap is checked in *every* window,
            # not just the first one over and over again.
            npt.assert_almost_equal(pt[0, t_win + mid_interpulse], 0)
    npt.assert_almost_equal(pt.time[0], 0)
    npt.assert_almost_equal(pt.time[-1], stim_dur, decimal=2)
    npt.assert_equal(pt.cathodic_first, cathodic_first)
    # Charge balance flag must agree with the numerically integrated charge:
    npt.assert_equal(pt.is_charge_balanced,
                     np.isclose(np.trapz(pt.data, pt.time)[0], 0, atol=1e-5))

    # NOTE(review): the zero-frequency, zero-amp and n_pulses checks below
    # exercise BiphasicPulseTrain rather than BiphasicTripletTrain — confirm
    # that this is intended.
    # Zero frequency:
    pt = BiphasicPulseTrain(0, amp, phase_dur)
    npt.assert_almost_equal(pt.time, [0, 1000])
    npt.assert_almost_equal(pt.data, 0)
    # Zero amp:
    pt = BiphasicPulseTrain(freq, 0, phase_dur)
    npt.assert_almost_equal(pt.data, 0)

    # Pulse can fill the entire window (no "unique time points" error):
    pt = BiphasicTripletTrain(10, 20, 100 / 6.001, stim_dur=500)
    npt.assert_almost_equal(pt.time[-1], 500)
    npt.assert_equal(np.round(np.trapz(np.abs(pt.data), pt.time)[0]), 9998)

    # Specific number of pulses
    for n_pulses in [2, 4, 5]:
        pt = BiphasicPulseTrain(500, 30, 0.05, n_pulses=n_pulses, stim_dur=19)
        npt.assert_almost_equal(np.sum(np.isclose(pt.data, 30)), 2 * n_pulses)
        npt.assert_almost_equal(pt.time[-1], 19)
def test_Nanduri2012Temporal():
    """Check Nanduri2012Temporal parameters, edge cases and reference values.

    Covers user-settable ``dt``, frozen parameters, empty and all-zero
    input, rejection of duplicate ``t_percept`` entries, and reference
    peak-brightness values for an amplitude sweep and a frequency sweep.
    """
    model = Nanduri2012Temporal()
    # User can set their own params:
    model.dt = 0.1
    npt.assert_equal(model.dt, 0.1)
    model.build(dt=1e-4)
    npt.assert_equal(model.dt, 1e-4)
    # User cannot add more model parameters:
    with pytest.raises(FreezeError):
        model.rho = 100

    # Nothing in, None out:
    npt.assert_equal(model.predict_percept(ArgusI().stim), None)

    # Zero in = zero out:
    implant = ArgusI(stim=np.zeros((16, 100)))
    percept = model.predict_percept(implant.stim, t_percept=[0, 1, 2])
    npt.assert_equal(isinstance(percept, Percept), True)
    npt.assert_equal(percept.shape, (16, 1, 3))
    npt.assert_almost_equal(percept.data, 0)

    # Can't request the same time more than once (this would break the Cython
    # loop, because `idx_frame` is incremented after a write; also doesn't
    # make much sense). The stimulus is assigned outside `pytest.raises` so
    # that only the duplicate time request can satisfy the check:
    implant.stim = np.ones((16, 100))
    with pytest.raises(ValueError):
        model.predict_percept(implant.stim, t_percept=[0.2, 0.2])

    # Brightness scales differently with amplitude vs frequency:
    model = Nanduri2012Temporal(dt=5e-3)
    model.build()
    sdur = 1000.0  # stimulus duration (ms)
    pdur = 0.45  # phase duration and interphase gap (ms)
    t_percept = np.arange(0, sdur, 5)
    implant = ProsthesisSystem(ElectrodeArray(DiskElectrode(0, 0, 0, 260)))

    def _peak_brightness(freq, amp):
        # Brightest value of the percept evoked by a biphasic pulse train.
        implant.stim = BiphasicPulseTrain(freq, amp, pdur,
                                          interphase_dur=pdur, stim_dur=sdur)
        percept = model.predict_percept(implant.stim, t_percept=t_percept)
        return percept.data.max()

    bright_amp = [_peak_brightness(20, amp) for amp in np.linspace(0, 50, 5)]
    bright_amp_ref = [0.0, 0.00890, 0.0657, 0.1500, 0.1691]
    npt.assert_almost_equal(bright_amp, bright_amp_ref, decimal=3)

    bright_freq = [_peak_brightness(freq, 20)
                   for freq in np.linspace(0, 100, 5)]
    bright_freq_ref = [0.0, 0.0394, 0.0741, 0.1073, 0.1385]
    npt.assert_almost_equal(bright_freq, bright_freq_ref, decimal=3)
def test_BiphasicTripletTrain(amp, interphase_dur, delay_dur, cathodic_first):
    """Sample a BiphasicTripletTrain at known offsets within each window.

    This variant uses a fixed sampling step ``dt`` and no ``interpulse_dur``
    argument. It also checks ``charge_balanced`` against a numerical
    integral of the waveform, plus a few BiphasicPulseTrain edge cases.
    """
    freq, stim_dur, phase_dur = 23.456, 657.456, 2
    dt = 1e-6
    window_dur = 1000.0 / freq
    n_pulses = int(freq * stim_dur / 1000.0)
    # Offsets (ms) of the probed points relative to the start of a window:
    t_first = delay_dur + 0.5 * phase_dur
    t_gap = delay_dur + phase_dur + 0.5 * interphase_dur
    t_second = delay_dur + interphase_dur + 1.5 * phase_dur
    magnitude = np.abs(amp)
    if cathodic_first:
        first_amp = -magnitude
    else:
        first_amp = magnitude
    second_amp = -first_amp

    # Basic usage:
    pt = BiphasicTripletTrain(freq,
                              amp,
                              phase_dur,
                              interphase_dur=interphase_dur,
                              delay_dur=delay_dur,
                              stim_dur=stim_dur,
                              cathodic_first=cathodic_first,
                              dt=dt)
    for t_win in (k * window_dur for k in range(n_pulses)):
        npt.assert_almost_equal(pt[0, t_win], 0)
        npt.assert_almost_equal(pt[0, t_win + t_first], first_amp)
        if interphase_dur > 0:
            npt.assert_almost_equal(pt[0, t_win + t_gap], 0)
        npt.assert_almost_equal(pt[0, t_win + t_second], second_amp)
    npt.assert_almost_equal(pt.time[0], 0)
    npt.assert_almost_equal(pt.time[-1], stim_dur, decimal=2)
    npt.assert_equal(pt.cathodic_first, cathodic_first)
    net_charge = np.trapz(pt.data, pt.time)[0]
    npt.assert_equal(pt.charge_balanced, np.isclose(net_charge, 0, atol=1e-5))

    # Zero frequency:
    pt = BiphasicPulseTrain(0, amp, phase_dur)
    npt.assert_almost_equal(pt.time, [0, 1000])
    npt.assert_almost_equal(pt.data, 0)
    # Zero amp:
    pt = BiphasicPulseTrain(freq, 0, phase_dur)
    npt.assert_almost_equal(pt.data, 0)

    # Specific number of pulses
    for count in (2, 4, 5):
        pt = BiphasicPulseTrain(500, 30, 0.05, n_pulses=count, stim_dur=19,
                                dt=0.05)
        npt.assert_almost_equal(np.sum(np.isclose(pt.data, 30)), count)
        npt.assert_almost_equal(pt.time[-1], 19)
def test_merge_time_axes_merge_tolerance():
    """Merging time axes must not corrupt stimulus data (regression test).

    Two pulse trains with different frequencies are combined into one
    ``Stimulus``. Previously too few unique time points were collected, and
    the resulting interpolation corrupted the stimulus data.
    See: https://github.com/pulse2percept/pulse2percept/issues/392
    """
    stim = Stimulus({
        "A2": BiphasicPulseTrain(20, 1, 0.45),
        "A10": BiphasicPulseTrain(30, 1, 0.45),
    })
    observed = np.unique(stim.data)

    # A corrupted data point would lie within 0.1 of +1/3 or -1/3:
    for corrupted in (1 / 3, -1 / 3):
        npt.assert_equal(np.isclose(corrupted, observed, atol=0.1).any(),
                         False)
Example #7
0
def test_Horsager2009Model():
    """Check Horsager2009Model attributes, ``dt`` handling and equivalence.

    The full model applied to a single point-source implant must give the
    same percept as Horsager2009Temporal applied to the stimulus directly.
    """
    model = Horsager2009Model()
    npt.assert_equal(hasattr(model, 'has_space'), True)
    npt.assert_equal(model.has_space, False)
    npt.assert_equal(hasattr(model, 'has_time'), True)
    npt.assert_equal(model.has_time, True)

    # User can set `dt`:
    model.temporal.dt = 1e-5
    npt.assert_almost_equal(model.dt, 1e-5)
    npt.assert_almost_equal(model.temporal.dt, 1e-5)
    model.build(dt=3e-4)
    npt.assert_almost_equal(model.dt, 3e-4)
    npt.assert_almost_equal(model.temporal.dt, 3e-4)

    # User cannot add more model parameters:
    with pytest.raises(FreezeError):
        model.rho = 100

    # Model and TemporalModel give the same result. Both models are built
    # once, outside the loop, instead of being rebuilt identically for every
    # stimulus (this assumes predict_percept does not mutate model state).
    full_model = Horsager2009Model().build()
    temporal_model = Horsager2009Temporal().build()
    for amp, freq in zip([136.02, 120.35, 57.71], [5, 15, 225]):
        stim = BiphasicPulseTrain(freq, amp, 0.075, interphase_dur=0.075,
                                  stim_dur=200, cathodic_first=True)
        implant = ProsthesisSystem(PointSource(0, 0, 0), stim=stim)
        npt.assert_almost_equal(full_model.predict_percept(implant).data,
                                temporal_model.predict_percept(stim).data)
def test_metadata():
    """Pulse-train and Stimulus metadata expose user data and parameters."""
    # A single pulse train stores the user string and its own parameters:
    stim = BiphasicPulseTrain(10, 10, 1, metadata='userdata')
    expected = {'user': 'userdata', 'freq': 10, 'amp': 10, 'phase_dur': 1,
                'delay_dur': 0}
    for key, value in expected.items():
        npt.assert_equal(stim.metadata[key], value)

    # A Stimulus stores its own user string plus per-electrode entries:
    trains = {
        'A2': BiphasicPulseTrain(10, 10, 1, metadata='userdataA2'),
        'B1': BiphasicPulseTrain(11, 9, 2, metadata='userdataB1'),
        'C3': BiphasicPulseTrain(12, 8, 3, metadata='userdataC3'),
    }
    stim = Stimulus(trains, metadata='stimulus_userdata')
    npt.assert_equal(stim.metadata['user'], 'stimulus_userdata')
    per_electrode = stim.metadata['electrodes']
    npt.assert_equal(per_electrode['A2']['type'], BiphasicPulseTrain)
    npt.assert_equal(per_electrode['B1']['metadata']['freq'], 11)
    npt.assert_equal(per_electrode['C3']['metadata']['user'], 'userdataC3')
def test_Horsager2009Temporal():
    """Check Horsager2009Temporal parameters, edge cases and reference values.

    Covers user-settable ``dt``, frozen parameters, empty and all-zero
    input, rejection of duplicate ``t_percept`` entries, and the peak
    brightness values labelled "Fig.3" (single pulses) and "Fig.4"
    (fixed-duration pulse trains) below.
    """
    model = Horsager2009Temporal()
    # User can set their own params:
    model.dt = 0.1
    npt.assert_equal(model.dt, 0.1)
    model.build(dt=1e-4)
    npt.assert_equal(model.dt, 1e-4)
    # User cannot add more model parameters:
    with pytest.raises(FreezeError):
        model.rho = 100

    # Nothing in, None out:
    implant = ProsthesisSystem(PointSource(0, 0, 0))
    npt.assert_equal(model.predict_percept(implant.stim), None)

    # Zero in = zero out:
    implant.stim = np.zeros((1, 6))
    percept = model.predict_percept(implant.stim, t_percept=[0, 1, 2])
    npt.assert_equal(isinstance(percept, Percept), True)
    npt.assert_equal(percept.shape, (1, 1, 3))
    npt.assert_almost_equal(percept.data, 0)

    # Can't request the same time more than once (this would break the Cython
    # loop, because `idx_frame` is incremented after a write; also doesn't
    # make much sense). The stimulus is assigned outside `pytest.raises` so
    # that only the duplicate time request can satisfy the check:
    implant.stim = np.ones((1, 100))
    with pytest.raises(ValueError):
        model.predict_percept(implant.stim, t_percept=[0.2, 0.2])

    def _peak_brightness(mdl, stim):
        # Sample the percept every `mdl.dt` from 0 up to about the last
        # stimulus time point and return the brightest value.
        t_percept = np.arange(0, stim.time[-1] + mdl.dt / 2, mdl.dt)
        return mdl.predict_percept(stim, t_percept=t_percept).data.max()

    # Single-pulse brightness from Fig.3:
    model = Horsager2009Temporal().build()
    for amp, pdur in zip([188.077, 89.74, 10.55], [0.075, 0.15, 4.0]):
        stim = BiphasicPulse(amp,
                             pdur,
                             interphase_dur=pdur,
                             stim_dur=200,
                             cathodic_first=True)
        npt.assert_almost_equal(_peak_brightness(model, stim), 110.3,
                                decimal=2)

    # Fixed-duration brightness from Fig.4:
    model = Horsager2009Temporal().build()
    for amp, freq in zip([136.02, 120.35, 57.71], [5, 15, 225]):
        stim = BiphasicPulseTrain(freq,
                                  amp,
                                  0.075,
                                  interphase_dur=0.075,
                                  stim_dur=200,
                                  cathodic_first=True)
        npt.assert_almost_equal(_peak_brightness(model, stim), 36.3,
                                decimal=2)
Generating a pulse train
------------------------

The first step is to build a pulse train using the
:py:class:`~pulse2percept.stimuli.BiphasicPulseTrain` class.
We want to generate a 20Hz pulse train (0.45ms phase duration and interphase
gap, cathodic-first) at 30uA that lasts for a second:

"""
# sphinx_gallery_thumbnail_number = 4
from pulse2percept.stimuli import BiphasicPulseTrain, Stimulus
# NOTE(review): `tsample` and `Stimulus` are not used in the visible part of
# this example; they may be referenced further down — confirm before removing.
tsample = 0.005  # sampling time step (ms)
phase_dur = 0.45  # duration of the cathodic/anodic phase (ms)
stim_dur = 1000  # stimulus duration (ms)
amp_th = 30  # threshold current (uA)
# 20 Hz biphasic pulse train at `amp_th`, with an interphase gap as long as a
# single phase, lasting `stim_dur` ms:
stim = BiphasicPulseTrain(20, amp_th, phase_dur, interphase_dur=phase_dur,
                          stim_dur=stim_dur)

# Configure Matplotlib:
import matplotlib.pyplot as plt
plt.style.use('ggplot')
from matplotlib import rc
rc('font', size=12)

# Plot the stimulus in the range t=[0, 60] ms:
stim.plot(time=(0, 60))

###############################################################################
# Creating an implant
# -------------------
#
# Before we can run the Nanduri model, we need to create a retinal implant to
Example #11
0
def test_Nanduri2012Model_predict_percept():
    """Exercise Nanduri2012Model.predict_percept end to end.

    Covers empty and zero input, agreement with Nanduri2012Temporal for a
    single-pixel grid, rejection of non-DiskElectrode arrays, validation of
    ``t_percept``, output shapes, and reference brightness-vs-size counts
    for an amplitude sweep and a frequency sweep.
    """
    # Nothing in = nothing out:
    model = Nanduri2012Model(xrange=(0, 0), yrange=(0, 0), engine='serial')
    model.build()
    implant = ArgusI(stim=None)
    npt.assert_equal(model.predict_percept(implant), None)
    implant.stim = np.zeros(16)
    npt.assert_almost_equal(model.predict_percept(implant).data, 0)

    # Single-pixel model same as TemporalModel:
    implant = ProsthesisSystem(DiskElectrode(0, 0, 0, 100))
    # implant.stim = PulseTrain(5e-6)
    implant.stim = BiphasicPulseTrain(20, 20, 0.45, interphase_dur=0.45)
    t_percept = [0, 0.01, 1.0]
    percept = model.predict_percept(implant, t_percept=t_percept)
    temp = Nanduri2012Temporal().build()
    temp = temp.predict_percept(implant.stim, t_percept=t_percept)
    npt.assert_almost_equal(percept.data, temp.data, decimal=4)

    # Only works for DiskElectrode arrays:
    # NOTE(review): the implant setup sits inside `pytest.raises`, so a
    # TypeError raised while building the implant would also satisfy these
    # checks — consider moving the setup out of the `with` blocks.
    with pytest.raises(TypeError):
        implant = ProsthesisSystem(ElectrodeArray(PointSource(0, 0, 0)))
        implant.stim = 1
        model.predict_percept(implant)
    with pytest.raises(TypeError):
        implant = ProsthesisSystem(
            ElectrodeArray(
                [DiskElectrode(0, 0, 0, 100),
                 PointSource(100, 100, 0)]))
        implant.stim = [1, 1]
        model.predict_percept(implant)

    # Requested times must be multiples of model.dt:
    implant = ProsthesisSystem(ElectrodeArray(DiskElectrode(0, 0, 0, 260)))
    # implant.stim = PulseTrain(tsample)
    implant.stim = BiphasicPulseTrain(20, 20, 0.45)
    model.temporal.dt = 0.1
    with pytest.raises(ValueError):
        model.predict_percept(implant, t_percept=[0.01])
    with pytest.raises(ValueError):
        model.predict_percept(implant, t_percept=[0.01, 1.0])
    with pytest.raises(ValueError):
        model.predict_percept(implant, t_percept=np.arange(0, 0.5, 0.101))
    # The step exceeds the stop value, so this arange is just [0.] — a valid
    # request that must not raise:
    model.predict_percept(implant, t_percept=np.arange(0, 0.5, 1.0000001))

    # Can't request the same time more than once (this would break the Cython
    # loop, because `idx_frame` is incremented after a write; also doesn't
    # make much sense):
    with pytest.raises(ValueError):
        model.predict_percept(implant, t_percept=[0.2, 0.2])

    # Requesting a time far beyond the end of `stim` is allowed (no extra
    # flag is passed here) and yields an all-zero percept:
    model.temporal.dt = 1e-2
    npt.assert_almost_equal(
        model.predict_percept(implant, t_percept=10000).data, 0)

    # Output shape must be determined by t_percept:
    npt.assert_equal(
        model.predict_percept(implant, t_percept=0).shape, (1, 1, 1))
    npt.assert_equal(
        model.predict_percept(implant, t_percept=[0, 1]).shape, (1, 1, 2))

    # Brightness vs. size (use values from Nanduri paper):
    model = Nanduri2012Model(xystep=0.5, xrange=(-4, 4), yrange=(-4, 4))
    model.build()
    implant = ProsthesisSystem(ElectrodeArray(DiskElectrode(0, 0, 0, 260)))
    amp_th = 30
    bright_th = 0.107
    stim_dur = 1000.0
    pdur = 0.45
    t_percept = np.arange(0, stim_dur, 5)
    amp_factors = [1, 6]
    frames_amp = []
    for amp_f in amp_factors:
        implant.stim = BiphasicPulseTrain(20,
                                          amp_f * amp_th,
                                          pdur,
                                          interphase_dur=pdur,
                                          stim_dur=stim_dur)
        percept = model.predict_percept(implant, t_percept=t_percept)
        # Last axis is time; pick the frame whose maximum over the two
        # non-time axes is largest:
        idx_frame = np.argmax(np.max(percept.data, axis=(0, 1)))
        brightest_frame = percept.data[..., idx_frame]
        frames_amp.append(brightest_frame)
    # Number of pixels brighter than `bright_th` in each brightest frame:
    npt.assert_equal([np.sum(f > bright_th) for f in frames_amp], [0, 161])
    freqs = [20, 120]
    frames_freq = []
    for freq in freqs:
        implant.stim = BiphasicPulseTrain(freq,
                                          1.25 * amp_th,
                                          pdur,
                                          interphase_dur=pdur,
                                          stim_dur=stim_dur)
        percept = model.predict_percept(implant, t_percept=t_percept)
        idx_frame = np.argmax(np.max(percept.data, axis=(0, 1)))
        brightest_frame = percept.data[..., idx_frame]
        frames_freq.append(brightest_frame)
    npt.assert_equal([np.sum(f > bright_th) for f in frames_freq], [21, 49])
# The same procedure can be repeated for
# :py:class:`~pulse2percept.stimuli.BiphasicPulseTrain` stimuli to reproduce
# Fig. 4.

from pulse2percept.stimuli import BiphasicPulseTrain

# Load the data:
# NOTE(review): `data`, `model` and `plt` are defined earlier in this example
# (not visible here); `data` appears to be a pandas DataFrame, given the
# boolean-mask indexing and `.iterrows()` below — confirm.
fixed_dur = data[(data.stim_type == 'fixed_duration') & (data.subject == 'S05')
                 & (data.electrode == 'C3') & (data.pulse_dur == 0.075)]

# Find the threshold:
# For each row, build a unit-amplitude pulse train with that row's frequency,
# pulse/interphase duration and stimulus duration. `find_threshold`
# presumably searches `amp_range` for the amplitude whose predicted
# brightness matches the row's `theta` — confirm against its documentation.
amp_th = []
for _, row in fixed_dur.iterrows():
    stim = BiphasicPulseTrain(row['stim_freq'],
                              1,
                              row['pulse_dur'],
                              interphase_dur=row['interphase_dur'],
                              stim_dur=row['stim_dur'],
                              cathodic_first=True)
    amp_th.append(
        model.find_threshold(stim,
                             row['theta'],
                             amp_range=(0, 300),
                             amp_tol=1,
                             bright_tol=0.1))

# Measured thresholds (squares) vs. model thresholds (line), log-scaled
# frequency axis:
plt.semilogx(fixed_dur.stim_freq, fixed_dur.stim_amp, 's', label='data')
plt.semilogx(fixed_dur.stim_freq, amp_th, 'k-', linewidth=3, label='model')
plt.xticks([5, 15, 75, 225])
plt.xlabel('frequency (Hz)')
plt.ylabel('threshold current (uA)')
plt.legend()
Example #13
0
A series of biphasic pulses can be created with the
:py:class:`~pulse2percept.stimuli.BiphasicPulseTrain` class.

You have the same options as when setting up a single
:py:class:`~pulse2percept.stimuli.BiphasicPulse`, in addition to specifying
a pulse train frequency (``freq``) and total stimulus duration (``stim_dur``).

For example, a 20 Hz pulse train lasting 200 ms and made from anodic-first
biphasic pulses (30 uA, 2 ms phase duration, no interphase gap) can be
created as follows:
"""
# sphinx_gallery_thumbnail_number = 4

from pulse2percept.stimuli import BiphasicPulseTrain

# Positional arguments are (freq, amp, phase_dur): 20 Hz, 30 uA, 2 ms.
# `cathodic_first=False` makes every pulse anodic-first.
pt = BiphasicPulseTrain(20, 30, 2, stim_dur=200, cathodic_first=False)
pt.plot()

###############################################################################
# You can also limit the number of pulses in the train, but still make the
# stimulus last 200 ms:

# `n_pulses=3` caps the number of pulses while `stim_dur` keeps the total
# duration at 200 ms.
pt = BiphasicPulseTrain(20,
                        30,
                        2,
                        n_pulses=3,
                        stim_dur=200,
                        cathodic_first=False)
pt.plot()

###############################################################################
def test_biphasicAxonMapSpatial(engine):
    """Exercise BiphasicAxonMapSpatial / BiphasicAxonMapModel predictions.

    Covers parameter validation, input-type checks, empty and zero input,
    equivalence with AxonMapSpatial when all effect models return 1, the
    ``t_percept`` frame layout, and reference pixel counts for the default
    effect models. The 'jax' engine is expected to raise NotImplementedError.
    """
    # Lambda cannot be too small:
    with pytest.raises(ValueError):
        BiphasicAxonMapSpatial(axlambda=9).build()

    model = BiphasicAxonMapModel(engine=engine, xystep=2).build()
    # Jax not implemented yet
    if engine == 'jax':
        # Build the implant outside `pytest.raises` so that only the
        # prediction itself can satisfy the NotImplementedError check.
        implant = ArgusII()
        implant.stim = Stimulus({'A5': BiphasicPulseTrain(20, 1, 0.45)})
        with pytest.raises(NotImplementedError):
            model.predict_percept(implant)
        return

    # Only accepts biphasic pulse trains with no delay dur
    implant = ArgusI(stim=np.ones(16))
    with pytest.raises(TypeError):
        model.predict_percept(implant)

    # Nothing in, None out:
    npt.assert_equal(model.predict_percept(ArgusI()), None)

    # Zero in = zero out:
    implant = ArgusI(stim=np.zeros(16))
    percept = model.predict_percept(implant)
    npt.assert_equal(isinstance(percept, Percept), True)
    npt.assert_equal(percept.shape, list(model.grid.x.shape) + [1])
    npt.assert_almost_equal(percept.data, 0)
    npt.assert_equal(percept.time, None)

    # Should be equal to axon map model if effects models return 1
    model = BiphasicAxonMapSpatial(engine=engine, xystep=2)

    def bright_model(freq, amp, pdur):
        return 1

    def size_model(freq, amp, pdur):
        return 1

    def streak_model(freq, amp, pdur):
        return 1

    model.bright_model = bright_model
    model.size_model = size_model
    model.streak_model = streak_model
    model.build()
    axon_map = AxonMapSpatial(xystep=2).build()
    implant = ArgusII()
    implant.stim = Stimulus({'A5': BiphasicPulseTrain(20, 1, 0.45)})
    percept = model.predict_percept(implant)
    percept_axon = axon_map.predict_percept(implant)
    npt.assert_almost_equal(percept.data[:, :, 0],
                            percept_axon.get_brightest_frame())

    # Effect models must be callable
    model = BiphasicAxonMapSpatial(engine=engine, xystep=2)
    model.bright_model = 1.0
    with pytest.raises(TypeError):
        model.build()

    # If t_percept is not specified, there should only be one frame
    model = BiphasicAxonMapSpatial(engine=engine, xystep=2)
    model.build()
    implant = ArgusII()
    implant.stim = Stimulus({'A5': BiphasicPulseTrain(20, 1, 0.45)})
    percept = model.predict_percept(implant)
    npt.assert_equal(percept.time is None, True)
    # If t_percept is specified, only first frame should have data
    # and the rest should be empty
    percept = model.predict_percept(implant, t_percept=[0, 1, 2, 5, 10])
    npt.assert_equal(len(percept.time), 5)
    npt.assert_equal(np.any(percept.data[:, :, 0]), True)
    npt.assert_equal(np.any(percept.data[:, :, 1:]), False)

    # Test that default models give expected values
    model = BiphasicAxonMapSpatial(engine=engine,
                                   rho=400,
                                   axlambda=600,
                                   xystep=1,
                                   xrange=(-20, 20),
                                   yrange=(-15, 15))
    model.build()
    implant = ArgusII()
    implant.stim = Stimulus({'A4': BiphasicPulseTrain(20, 1, 1)})
    percept = model.predict_percept(implant)
    # Reference counts of pixels above increasing brightness thresholds:
    npt.assert_equal(np.sum(percept.data > 0.0813), 81)
    npt.assert_equal(np.sum(percept.data > 0.1626), 59)
    npt.assert_equal(np.sum(percept.data > 0.2439), 44)
    npt.assert_equal(np.sum(percept.data > 0.4065), 26)
    npt.assert_equal(np.sum(percept.data > 0.5691), 14)
Example #15
0
# NOTE(review): `implant`, `model` and `plt` are created earlier in this
# example (not visible here).
implant.plot()
plt.show()

##############################################################################
# As mentioned above, the Biphasic Axon Map Model only accepts
# :py:class:`~pulse2percept.stimuli.BiphasicPulseTrain`
# stimuli with no :py:attr:`~pulse2percept.stimuli.BiphasicPulseTrain.delay_dur`.
# The amplitude given to the BiphasicPulseTrain
# is interpreted as amplitude as a factor of threshold (i.e. an amp of 1 means
# 1xTh)
#
# You can easily assign BiphasicPulseTrains to electrodes with a dictionary
# The following creates a train with 20Hz frequency, 1xTh amplitude, and 0.45ms
# pulse / phase duration.

implant.stim = {'A4': BiphasicPulseTrain(20, 1, 0.45)}
implant.stim.plot()

##############################################################################
# Finally, you can predict the percept resulting from stimulation

percept = model.predict_percept(implant)
ax = percept.plot()
ax.set_title('Predicted percept')
plt.show()
##############################################################################
# Increasing the frequency will make phosphenes brighter
fig, axes = plt.subplots(1, 2, sharex=True, sharey=True)
# Same electrode and amplitude as above, but at 50 Hz instead of 20 Hz.
# NOTE(review): only axes[1] is filled in the visible code; presumably the
# 20 Hz `percept` is plotted on axes[0] further down — confirm.
implant.stim = {'A4': BiphasicPulseTrain(50, 1, 0.45)}
new_percept = model.predict_percept(implant)
new_percept.plot(ax=axes[1])