def test_predict_batched(engine):
    """``predict_percept_batched`` must agree with repeated ``predict_percept``.

    Only the 'jax' engine is expected to support the batched call; for any
    other engine the batched call is expected to raise ``ImportError``.
    """
    if not has_jax:
        pytest.skip("Jax not installed")

    # Plain dicts and Stimulus objects may be mixed within one batch:
    first = {'A5': BiphasicPulseTrain(25, 4, 0.45),
             'C7': BiphasicPulseTrain(50, 2.5, 0.75)}
    second = {'B4': BiphasicPulseTrain(3, 1, 0.32)}
    third = Stimulus({'F3': BiphasicPulseTrain(12, 3, 1.2)})
    stims = [first, second, third]

    implant = ArgusII()
    model = BiphasicAxonMapModel(engine=engine, xystep=2)
    model.build()

    if engine != 'jax':
        # Non-jax engines are expected to reject the batched API:
        with pytest.raises(ImportError):
            model.predict_percept_batched(implant, stims)
        return

    batched = model.predict_percept_batched(implant, stims)

    # Reference: predict each stimulus on its own, one after the other.
    one_by_one = []
    for stim in stims:
        implant.stim = stim
        one_by_one.append(model.predict_percept(implant))

    npt.assert_equal(len(one_by_one), len(batched))
    for from_batch, from_loop in zip(batched, one_by_one):
        npt.assert_almost_equal(from_batch.data, from_loop.data)
def test_predict_spatial_jax():
    """The jax engine must reproduce the cython engine's spatial percepts.

    The comparison is repeated after changing ``axlambda`` and ``rho`` and
    rebuilding, to make sure the jax model does not reuse stale cached state.
    """
    if not has_jax:
        pytest.skip("Jax not installed")

    jax_model = BiphasicAxonMapModel(engine='jax', xystep=2)
    cython_model = BiphasicAxonMapModel(engine='cython', xystep=2)
    engines = (jax_model, cython_model)
    for m in engines:
        m.build()

    implant = ArgusII()
    implant.stim = {
        'A5': BiphasicPulseTrain(25, 4, 0.45),
        'C7': BiphasicPulseTrain(50, 2.5, 0.75)
    }

    def _check_engines_agree():
        # Predict with jax first, then cython (same order as before):
        from_jax, from_cython = [m.predict_percept(implant) for m in engines]
        npt.assert_almost_equal(from_jax.data, from_cython.data, decimal=4)

    _check_engines_agree()

    # Change model parameters, rebuild, and compare again (jax must clear its
    # cache on build):
    for attr, value in (('axlambda', 800), ('rho', 50)):
        for m in engines:
            setattr(m, attr, value)
    for m in engines:
        m.build()
    _check_engines_agree()
def test_BiphasicTripletTrain(amp, interphase_dur, interpulse_dur, delay_dur,
                              cathodic_first):
    """Probe a BiphasicTripletTrain window by window.

    In every pulse window the stimulus is sampled just before the window
    onset, in the middle of both phases of the first biphasic pulse, in the
    interphase gap and in the interpulse gap. The test then checks the time
    axis, polarity flag and charge-balance flag.

    Parameters
    ----------
    amp : float
        Pulse amplitude; only its magnitude matters here, the expected sign
        is derived from ``cathodic_first``.
    interphase_dur, interpulse_dur, delay_dur : float
        Durations (ms) passed straight to the constructor.
    cathodic_first : bool
        Whether the first phase is expected to be negative.
    """
    freq = 23.456
    stim_dur = 657.456
    phase_dur = 2
    window_dur = 1000.0 / freq
    n_pulses = int(freq * stim_dur / 1000.0)
    # Probe points (ms), relative to the start of a pulse window:
    mid_first_pulse = delay_dur + phase_dur / 2.0
    mid_interphase = delay_dur + phase_dur + interphase_dur / 2.0
    mid_second_pulse = delay_dur + interphase_dur + 1.5 * phase_dur
    mid_interpulse = (delay_dur + 2.0 * phase_dur + interphase_dur +
                      interpulse_dur / 2.0)
    first_amp = -np.abs(amp) if cathodic_first else np.abs(amp)
    second_amp = -first_amp

    # Basic usage:
    pt = BiphasicTripletTrain(freq, amp, phase_dur,
                              interphase_dur=interphase_dur,
                              interpulse_dur=interpulse_dur,
                              delay_dur=delay_dur, stim_dur=stim_dur,
                              cathodic_first=cathodic_first)
    for i in range(n_pulses):
        t_win = i * window_dur
        npt.assert_almost_equal(pt[0, np.floor(t_win)], 0)
        npt.assert_almost_equal(pt[0, t_win + mid_first_pulse], first_amp)
        if interphase_dur > 0:
            npt.assert_almost_equal(pt[0, t_win + mid_interphase], 0)
        npt.assert_almost_equal(pt[0, t_win + mid_second_pulse], second_amp)
        if interpulse_dur > 0:
            # Offset by the window start like every other probe, so the
            # interpulse gap is checked in each window, not just the first.
            npt.assert_almost_equal(pt[0, t_win + mid_interpulse], 0)
    npt.assert_almost_equal(pt.time[0], 0)
    npt.assert_almost_equal(pt.time[-1], stim_dur, decimal=2)
    npt.assert_equal(pt.cathodic_first, cathodic_first)
    npt.assert_equal(pt.is_charge_balanced,
                     np.isclose(np.trapz(pt.data, pt.time)[0], 0, atol=1e-5))

    # NOTE(review): the zero-frequency and zero-amplitude checks below build
    # a BiphasicPulseTrain, not a BiphasicTripletTrain -- confirm intended.
    # Zero frequency:
    pt = BiphasicPulseTrain(0, amp, phase_dur)
    npt.assert_almost_equal(pt.time, [0, 1000])
    npt.assert_almost_equal(pt.data, 0)

    # Zero amp:
    pt = BiphasicPulseTrain(freq, 0, phase_dur)
    npt.assert_almost_equal(pt.data, 0)

    # Pulse can fill the entire window (no "unique time points" error):
    pt = BiphasicTripletTrain(10, 20, 100 / 6.001, stim_dur=500)
    npt.assert_almost_equal(pt.time[-1], 500)
    npt.assert_equal(np.round(np.trapz(np.abs(pt.data), pt.time)[0]), 9998)

    # Specific number of pulses:
    for n_requested in [2, 4, 5]:
        pt = BiphasicPulseTrain(500, 30, 0.05, n_pulses=n_requested,
                                stim_dur=19)
        npt.assert_almost_equal(np.sum(np.isclose(pt.data, 30)),
                                2 * n_requested)
        npt.assert_almost_equal(pt.time[-1], 19)
def test_Nanduri2012Temporal():
    """Exercise Nanduri2012Temporal: parameter handling, degenerate inputs,
    and brightness as a function of pulse amplitude and frequency."""
    model = Nanduri2012Temporal()
    # User can set their own params:
    model.dt = 0.1
    npt.assert_equal(model.dt, 0.1)
    model.build(dt=1e-4)
    npt.assert_equal(model.dt, 1e-4)
    # User cannot add more model parameters:
    with pytest.raises(FreezeError):
        model.rho = 100

    # Nothing in, None out:
    npt.assert_equal(model.predict_percept(ArgusI().stim), None)

    # Zero in = zero out:
    implant = ArgusI(stim=np.zeros((16, 100)))
    percept = model.predict_percept(implant.stim, t_percept=[0, 1, 2])
    npt.assert_equal(isinstance(percept, Percept), True)
    npt.assert_equal(percept.shape, (16, 1, 3))
    npt.assert_almost_equal(percept.data, 0)

    # Can't request the same time more than once (this would break the Cython
    # loop, because `idx_frame` is incremented after a write; also doesn't
    # make much sense). The stimulus is assigned before the `raises` block so
    # that only `predict_percept` can produce the expected ValueError:
    implant.stim = np.ones((16, 100))
    with pytest.raises(ValueError):
        model.predict_percept(implant.stim, t_percept=[0.2, 0.2])

    # Brightness scales differently with amplitude vs frequency:
    model = Nanduri2012Temporal(dt=5e-3)
    model.build()
    sdur = 1000.0  # stimulus duration (ms)
    pdur = 0.45  # phase duration (ms), also used as the interphase gap
    t_percept = np.arange(0, sdur, 5)
    implant = ProsthesisSystem(ElectrodeArray(DiskElectrode(0, 0, 0, 260)))

    def _max_brightness(freq, amp):
        # Peak percept value over all requested frames for one pulse train:
        implant.stim = BiphasicPulseTrain(freq, amp, pdur,
                                          interphase_dur=pdur, stim_dur=sdur)
        percept = model.predict_percept(implant.stim, t_percept=t_percept)
        return percept.data.max()

    bright_amp = [_max_brightness(20, amp) for amp in np.linspace(0, 50, 5)]
    bright_amp_ref = [0.0, 0.00890, 0.0657, 0.1500, 0.1691]
    npt.assert_almost_equal(bright_amp, bright_amp_ref, decimal=3)

    bright_freq = [_max_brightness(freq, 20)
                   for freq in np.linspace(0, 100, 5)]
    bright_freq_ref = [0.0, 0.0394, 0.0741, 0.1073, 0.1385]
    npt.assert_almost_equal(bright_freq, bright_freq_ref, decimal=3)
def test_BiphasicTripletTrain(amp, interphase_dur, delay_dur, cathodic_first):
    """Probe a finely sampled (``dt=1e-6``) BiphasicTripletTrain per window.

    At every window onset the signal must be zero. Mid-phase samples must
    carry the expected polarity and the interphase gap must be silent. The
    time axis, polarity flag and charge-balance flag are checked afterwards.
    """
    freq, stim_dur, phase_dur, dt = 23.456, 657.456, 2, 1e-6
    window_dur = 1000.0 / freq
    n_pulses = int(freq * stim_dur / 1000.0)
    # Probe points (ms), relative to the onset of each pulse window:
    t_first = delay_dur + phase_dur / 2.0
    t_gap = delay_dur + phase_dur + interphase_dur / 2.0
    t_second = delay_dur + interphase_dur + 1.5 * phase_dur
    if cathodic_first:
        first_amp = -np.abs(amp)
    else:
        first_amp = np.abs(amp)
    second_amp = -first_amp

    # Basic usage:
    pt = BiphasicTripletTrain(freq, amp, phase_dur,
                              interphase_dur=interphase_dur,
                              delay_dur=delay_dur, stim_dur=stim_dur,
                              cathodic_first=cathodic_first, dt=dt)
    for onset in (k * window_dur for k in range(n_pulses)):
        npt.assert_almost_equal(pt[0, onset], 0)
        npt.assert_almost_equal(pt[0, onset + t_first], first_amp)
        if interphase_dur > 0:
            npt.assert_almost_equal(pt[0, onset + t_gap], 0)
        npt.assert_almost_equal(pt[0, onset + t_second], second_amp)
    npt.assert_almost_equal(pt.time[0], 0)
    npt.assert_almost_equal(pt.time[-1], stim_dur, decimal=2)
    npt.assert_equal(pt.cathodic_first, cathodic_first)
    balanced = pt.charge_balanced
    net_charge = np.trapz(pt.data, pt.time)[0]
    npt.assert_equal(balanced, np.isclose(net_charge, 0, atol=1e-5))

    # Zero frequency:
    pt = BiphasicPulseTrain(0, amp, phase_dur)
    npt.assert_almost_equal(pt.time, [0, 1000])
    npt.assert_almost_equal(pt.data, 0)

    # Zero amp:
    pt = BiphasicPulseTrain(freq, 0, phase_dur)
    npt.assert_almost_equal(pt.data, 0)

    # Specific number of pulses:
    for count in (2, 4, 5):
        pt = BiphasicPulseTrain(500, 30, 0.05, n_pulses=count, stim_dur=19,
                                dt=0.05)
        npt.assert_almost_equal(np.sum(np.isclose(pt.data, 30)), count)
        npt.assert_almost_equal(pt.time[-1], 19)
def test_merge_time_axes_merge_tolerance():
    """Merging two pulse trains' time axes must not corrupt the samples.

    Regression test: not enough unique time points were collected, so
    interpolation produced spurious in-between values (around +/- 1/3).
    See: https://github.com/pulse2percept/pulse2percept/issues/392
    """
    train_20hz = BiphasicPulseTrain(20, 1, 0.45)
    train_30hz = BiphasicPulseTrain(30, 1, 0.45)
    merged = Stimulus({"A2": train_20hz, "A10": train_30hz})
    values = np.unique(merged.data)
    # No value may land near +1/3 or -1/3, i.e. a corrupted data point:
    for corrupted in (1 / 3, -1 / 3):
        npt.assert_equal(np.isclose(corrupted, values, atol=0.1).any(), False)
def test_Horsager2009Model():
    """Horsager2009Model is time-only and, for a single point source, gives
    the same result as Horsager2009Temporal applied to the stimulus."""
    model = Horsager2009Model()
    npt.assert_equal(hasattr(model, 'has_space'), True)
    npt.assert_equal(model.has_space, False)
    npt.assert_equal(hasattr(model, 'has_time'), True)
    npt.assert_equal(model.has_time, True)

    # User can set `dt`:
    model.temporal.dt = 1e-5
    npt.assert_almost_equal(model.dt, 1e-5)
    npt.assert_almost_equal(model.temporal.dt, 1e-5)
    model.build(dt=3e-4)
    npt.assert_almost_equal(model.dt, 3e-4)
    npt.assert_almost_equal(model.temporal.dt, 3e-4)

    # User cannot add more model parameters:
    with pytest.raises(FreezeError):
        model.rho = 100

    # Model and TemporalModel give the same result. Both use default
    # parameters, so they are built once rather than on every iteration:
    full_model = Horsager2009Model().build()
    temporal_model = Horsager2009Temporal().build()
    for amp, freq in zip([136.02, 120.35, 57.71], [5, 15, 225]):
        stim = BiphasicPulseTrain(freq, amp, 0.075, interphase_dur=0.075,
                                  stim_dur=200, cathodic_first=True)
        implant = ProsthesisSystem(PointSource(0, 0, 0), stim=stim)
        npt.assert_almost_equal(full_model.predict_percept(implant).data,
                                temporal_model.predict_percept(stim).data)
def test_metadata():
    """User metadata and constructor arguments end up in ``.metadata``.

    Covers a single BiphasicPulseTrain as well as a Stimulus built from
    several trains, where each electrode's entry records its type and the
    train's own metadata.
    """
    train = BiphasicPulseTrain(10, 10, 1, metadata='userdata')
    expected = {'user': 'userdata', 'freq': 10, 'amp': 10, 'phase_dur': 1,
                'delay_dur': 0}
    for key, value in expected.items():
        npt.assert_equal(train.metadata[key], value)

    trains = {
        electrode: BiphasicPulseTrain(freq, amp, pdur, metadata=user)
        for electrode, freq, amp, pdur, user in (
            ('A2', 10, 10, 1, 'userdataA2'),
            ('B1', 11, 9, 2, 'userdataB1'),
            ('C3', 12, 8, 3, 'userdataC3'),
        )
    }
    stim = Stimulus(trains, metadata='stimulus_userdata')
    npt.assert_equal(stim.metadata['user'], 'stimulus_userdata')
    per_electrode = stim.metadata['electrodes']
    npt.assert_equal(per_electrode['A2']['type'], BiphasicPulseTrain)
    npt.assert_equal(per_electrode['B1']['metadata']['freq'], 11)
    npt.assert_equal(per_electrode['C3']['metadata']['user'], 'userdataC3')
def test_Horsager2009Temporal():
    """Exercise Horsager2009Temporal: parameter handling, degenerate inputs,
    and peak brightness for the stimuli of Figs. 3 and 4."""
    model = Horsager2009Temporal()
    # User can set their own params:
    model.dt = 0.1
    npt.assert_equal(model.dt, 0.1)
    model.build(dt=1e-4)
    npt.assert_equal(model.dt, 1e-4)
    # User cannot add more model parameters:
    with pytest.raises(FreezeError):
        model.rho = 100

    # Nothing in, None out:
    implant = ProsthesisSystem(PointSource(0, 0, 0))
    npt.assert_equal(model.predict_percept(implant.stim), None)

    # Zero in = zero out:
    implant.stim = np.zeros((1, 6))
    percept = model.predict_percept(implant.stim, t_percept=[0, 1, 2])
    npt.assert_equal(isinstance(percept, Percept), True)
    npt.assert_equal(percept.shape, (1, 1, 3))
    npt.assert_almost_equal(percept.data, 0)

    # Can't request the same time more than once (this would break the Cython
    # loop, because `idx_frame` is incremented after a write; also doesn't
    # make much sense). The stimulus is assigned before the `raises` block so
    # that only `predict_percept` can produce the expected ValueError:
    implant.stim = np.ones((1, 100))
    with pytest.raises(ValueError):
        model.predict_percept(implant.stim, t_percept=[0.2, 0.2])

    # One default-parameter model serves both figure reproductions below:
    model = Horsager2009Temporal().build()

    def _peak_brightness(stim):
        # Sample every model time step up to and including the stimulus end
        # (the half-step pad keeps the endpoint despite `arange`'s open stop):
        t_percept = np.arange(0, stim.time[-1] + model.dt / 2, model.dt)
        return model.predict_percept(stim, t_percept=t_percept).data.max()

    # Single-pulse brightness from Fig.3:
    for amp, pdur in zip([188.077, 89.74, 10.55], [0.075, 0.15, 4.0]):
        stim = BiphasicPulse(amp, pdur, interphase_dur=pdur, stim_dur=200,
                             cathodic_first=True)
        npt.assert_almost_equal(_peak_brightness(stim), 110.3, decimal=2)

    # Fixed-duration brightness from Fig.4:
    for amp, freq in zip([136.02, 120.35, 57.71], [5, 15, 225]):
        stim = BiphasicPulseTrain(freq, amp, 0.075, interphase_dur=0.075,
                                  stim_dur=200, cathodic_first=True)
        npt.assert_almost_equal(_peak_brightness(stim), 36.3, decimal=2)
Generating a pulse train
------------------------

The first step is to build a pulse train using the
:py:class:`~pulse2percept.stimuli.BiphasicPulseTrain` class.
We want to generate a 20 Hz pulse train (0.45 ms phase duration, 0.45 ms
interphase gap, cathodic-first) at 30 uA that lasts for one second:

"""
# sphinx_gallery_thumbnail_number = 4

from pulse2percept.stimuli import BiphasicPulseTrain, Stimulus

tsample = 0.005  # sampling time step (ms)
# NOTE(review): `tsample` is not used in this part of the example; confirm it
# is still needed further down.
phase_dur = 0.45  # duration of the cathodic/anodic phase (ms)
stim_dur = 1000  # stimulus duration (ms)
amp_th = 30  # threshold current (uA)
# 20 Hz train at the threshold current; the interphase gap is as long as a
# single phase:
stim = BiphasicPulseTrain(20, amp_th, phase_dur, interphase_dur=phase_dur,
                          stim_dur=stim_dur)

# Configure Matplotlib:
import matplotlib.pyplot as plt
plt.style.use('ggplot')
from matplotlib import rc
rc('font', size=12)

# Plot the stimulus in the range t=[0, 60] ms:
stim.plot(time=(0, 60))

###############################################################################
# Creating an implant
# -------------------
#
# Before we can run the Nanduri model, we need to create a retinal implant to
def test_Nanduri2012Model_predict_percept():
    """End-to-end checks of Nanduri2012Model.predict_percept.

    Covers empty/zero input, agreement with Nanduri2012Temporal for a single
    pixel, input validation (electrode type, requested times), output shape,
    and the brightness-vs-size relationship reported by Nanduri et al.
    """
    # Nothing in = nothing out:
    model = Nanduri2012Model(xrange=(0, 0), yrange=(0, 0), engine='serial')
    model.build()
    implant = ArgusI(stim=None)
    npt.assert_equal(model.predict_percept(implant), None)
    implant.stim = np.zeros(16)
    npt.assert_almost_equal(model.predict_percept(implant).data, 0)

    # Single-pixel model same as TemporalModel:
    implant = ProsthesisSystem(DiskElectrode(0, 0, 0, 100))
    implant.stim = BiphasicPulseTrain(20, 20, 0.45, interphase_dur=0.45)
    t_percept = [0, 0.01, 1.0]
    percept = model.predict_percept(implant, t_percept=t_percept)
    temporal = Nanduri2012Temporal().build()
    temp_percept = temporal.predict_percept(implant.stim, t_percept=t_percept)
    npt.assert_almost_equal(percept.data, temp_percept.data, decimal=4)

    # Only works for DiskElectrode arrays. The implants are set up before the
    # `raises` blocks so that only `predict_percept` can raise the TypeError:
    implant = ProsthesisSystem(ElectrodeArray(PointSource(0, 0, 0)))
    implant.stim = 1
    with pytest.raises(TypeError):
        model.predict_percept(implant)
    implant = ProsthesisSystem(
        ElectrodeArray([DiskElectrode(0, 0, 0, 100),
                        PointSource(100, 100, 0)]))
    implant.stim = [1, 1]
    with pytest.raises(TypeError):
        model.predict_percept(implant)

    # Requested times must be multiples of model.dt:
    implant = ProsthesisSystem(ElectrodeArray(DiskElectrode(0, 0, 0, 260)))
    implant.stim = BiphasicPulseTrain(20, 20, 0.45)
    model.temporal.dt = 0.1
    for bad_times in ([0.01], [0.01, 1.0], np.arange(0, 0.5, 0.101)):
        with pytest.raises(ValueError):
            model.predict_percept(implant, t_percept=bad_times)
    # `np.arange(0, 0.5, 1.0000001)` is just [0.], a multiple of dt, so it
    # must be accepted (kept outside any `raises` block so it actually runs):
    model.predict_percept(implant, t_percept=np.arange(0, 0.5, 1.0000001))

    # Can't request the same time more than once (this would break the Cython
    # loop, because `idx_frame` is incremented after a write; also doesn't
    # make much sense):
    with pytest.raises(ValueError):
        model.predict_percept(implant, t_percept=[0.2, 0.2])

    # Requesting a time past the end of `stim` is allowed and yields zero:
    model.temporal.dt = 1e-2
    npt.assert_almost_equal(
        model.predict_percept(implant, t_percept=10000).data, 0)

    # Output shape must be determined by t_percept:
    npt.assert_equal(
        model.predict_percept(implant, t_percept=0).shape, (1, 1, 1))
    npt.assert_equal(
        model.predict_percept(implant, t_percept=[0, 1]).shape, (1, 1, 2))

    # Brightness vs. size (use values from Nanduri paper):
    model = Nanduri2012Model(xystep=0.5, xrange=(-4, 4), yrange=(-4, 4))
    model.build()
    implant = ProsthesisSystem(ElectrodeArray(DiskElectrode(0, 0, 0, 260)))
    amp_th = 30
    bright_th = 0.107
    stim_dur = 1000.0
    pdur = 0.45
    t_percept = np.arange(0, stim_dur, 5)

    def _n_bright_pixels(freq, amp):
        # Count pixels above `bright_th` in the brightest frame of the percept:
        implant.stim = BiphasicPulseTrain(freq, amp, pdur,
                                          interphase_dur=pdur,
                                          stim_dur=stim_dur)
        percept = model.predict_percept(implant, t_percept=t_percept)
        idx_frame = np.argmax(np.max(percept.data, axis=(0, 1)))
        brightest_frame = percept.data[..., idx_frame]
        return np.sum(brightest_frame > bright_th)

    amp_factors = [1, 6]
    npt.assert_equal([_n_bright_pixels(20, f * amp_th) for f in amp_factors],
                     [0, 161])
    freqs = [20, 120]
    npt.assert_equal([_n_bright_pixels(freq, 1.25 * amp_th)
                      for freq in freqs],
                     [21, 49])
# The same procedure can be repeated for
# :py:class:`~pulse2percept.stimuli.BiphasicPulseTrain` stimuli to reproduce
# Fig. 4.

from pulse2percept.stimuli import BiphasicPulseTrain

# Load the data (fixed-duration trials of subject S05, electrode C3,
# 0.075 ms pulses):
fig4_rows = ((data.stim_type == 'fixed_duration') &
             (data.subject == 'S05') &
             (data.electrode == 'C3') &
             (data.pulse_dur == 0.075))
fixed_dur = data[fig4_rows]

# Find the threshold amplitude for every trial:
amp_th = []
for _, row in fixed_dur.iterrows():
    stim = BiphasicPulseTrain(row['stim_freq'], 1, row['pulse_dur'],
                              interphase_dur=row['interphase_dur'],
                              stim_dur=row['stim_dur'],
                              cathodic_first=True)
    threshold = model.find_threshold(stim, row['theta'],
                                     amp_range=(0, 300), amp_tol=1,
                                     bright_tol=0.1)
    amp_th.append(threshold)

plt.semilogx(fixed_dur.stim_freq, fixed_dur.stim_amp, 's', label='data')
plt.semilogx(fixed_dur.stim_freq, amp_th, 'k-', linewidth=3, label='model')
plt.xticks([5, 15, 75, 225])
plt.xlabel('frequency (Hz)')
plt.ylabel('threshold current (uA)')
plt.legend()
A series of biphasic pulses can be created with the
:py:class:`~pulse2percept.stimuli.BiphasicPulseTrain` class.
You have the same options as when setting up a single
:py:class:`~pulse2percept.stimuli.BiphasicPulse`, in addition to specifying
a pulse train frequency (``freq``) and total stimulus duration (``stim_dur``).

For example, a 20 Hz pulse train lasting 200 ms and made from anodic-first
biphasic pulses (30 uA, 2 ms phase duration, no interphase gap) can be
created as follows:

"""
# sphinx_gallery_thumbnail_number = 4

from pulse2percept.stimuli import BiphasicPulseTrain

# Positional arguments: frequency (20 Hz), amplitude (30 uA), phase
# duration (2 ms):
pt = BiphasicPulseTrain(20, 30, 2, stim_dur=200, cathodic_first=False)
pt.plot()

###############################################################################
# You can also limit the number of pulses in the train, but still make the
# stimulus last 200 ms:

pt = BiphasicPulseTrain(20, 30, 2, n_pulses=3, stim_dur=200,
                        cathodic_first=False)
pt.plot()

###############################################################################
def test_biphasicAxonMapSpatial(engine):
    """Checks for BiphasicAxonMapSpatial / BiphasicAxonMapModel.

    Covers parameter validation, rejected stimulus types, degenerate inputs,
    equivalence with AxonMapSpatial when all effect models return 1,
    single-frame output, and reference pixel counts for the default models.
    """
    # Lambda cannot be too small:
    with pytest.raises(ValueError):
        BiphasicAxonMapSpatial(axlambda=9).build()

    model = BiphasicAxonMapModel(engine=engine, xystep=2).build()
    # Jax not implemented yet. The implant is set up before the `raises`
    # block so that only `predict_percept` can raise NotImplementedError:
    if engine == 'jax':
        implant = ArgusII()
        implant.stim = Stimulus({'A5': BiphasicPulseTrain(20, 1, 0.45)})
        with pytest.raises(NotImplementedError):
            model.predict_percept(implant)
        return

    # A non-zero stimulus that is not a BiphasicPulseTrain (here a raw array
    # of ones) is rejected:
    implant = ArgusI(stim=np.ones(16))
    with pytest.raises(TypeError):
        model.predict_percept(implant)

    # Nothing in, None out:
    npt.assert_equal(model.predict_percept(ArgusI()), None)

    # Zero in = zero out:
    implant = ArgusI(stim=np.zeros(16))
    percept = model.predict_percept(implant)
    npt.assert_equal(isinstance(percept, Percept), True)
    npt.assert_equal(percept.shape, list(model.grid.x.shape) + [1])
    npt.assert_almost_equal(percept.data, 0)
    npt.assert_equal(percept.time, None)

    # Should be equal to axon map model if effects models return 1
    def _unit_effect(freq, amp, pdur):
        return 1

    model = BiphasicAxonMapSpatial(engine=engine, xystep=2)
    model.bright_model = _unit_effect
    model.size_model = _unit_effect
    model.streak_model = _unit_effect
    model.build()
    axon_map = AxonMapSpatial(xystep=2).build()
    implant = ArgusII()
    implant.stim = Stimulus({'A5': BiphasicPulseTrain(20, 1, 0.45)})
    percept = model.predict_percept(implant)
    percept_axon = axon_map.predict_percept(implant)
    npt.assert_almost_equal(percept.data[:, :, 0],
                            percept_axon.get_brightest_frame())

    # Effect models must be callable
    model = BiphasicAxonMapSpatial(engine=engine, xystep=2)
    model.bright_model = 1.0
    with pytest.raises(TypeError):
        model.build()

    # If t_percept is not specified, there should only be one frame
    model = BiphasicAxonMapSpatial(engine=engine, xystep=2)
    model.build()
    implant = ArgusII()
    implant.stim = Stimulus({'A5': BiphasicPulseTrain(20, 1, 0.45)})
    percept = model.predict_percept(implant)
    npt.assert_equal(percept.time is None, True)

    # If t_percept is specified, only first frame should have data
    # and the rest should be empty
    percept = model.predict_percept(implant, t_percept=[0, 1, 2, 5, 10])
    npt.assert_equal(len(percept.time), 5)
    npt.assert_equal(np.any(percept.data[:, :, 0]), True)
    npt.assert_equal(np.any(percept.data[:, :, 1:]), False)

    # Test that default models give expected values
    model = BiphasicAxonMapSpatial(engine=engine, rho=400, axlambda=600,
                                   xystep=1, xrange=(-20, 20),
                                   yrange=(-15, 15))
    model.build()
    implant = ArgusII()
    implant.stim = Stimulus({'A4': BiphasicPulseTrain(20, 1, 1)})
    percept = model.predict_percept(implant)
    # (brightness threshold, expected number of pixels above it)
    expected_counts = [(0.0813, 81), (0.1626, 59), (0.2439, 44),
                       (0.4065, 26), (0.5691, 14)]
    for threshold, count in expected_counts:
        npt.assert_equal(np.sum(percept.data > threshold), count)
implant.plot()
plt.show()

##############################################################################
# As mentioned above, the Biphasic Axon Map Model only accepts
# :py:class:`~pulse2percept.stimuli.BiphasicPulseTrain`
# stimuli with no :py:attr:`~pulse2percept.stimuli.BiphasicPulseTrain.delay_dur`.
# The amplitude given to the BiphasicPulseTrain is interpreted as a multiple
# of the threshold amplitude (i.e., an amp of 1 means 1xTh).
#
# You can easily assign BiphasicPulseTrains to electrodes with a dictionary.
# The following creates a train with 20 Hz frequency, 1xTh amplitude, and
# 0.45 ms phase duration:

implant.stim = {'A4': BiphasicPulseTrain(20, 1, 0.45)}
implant.stim.plot()

##############################################################################
# Finally, you can predict the percept resulting from stimulation:

percept = model.predict_percept(implant)
ax = percept.plot()
ax.set_title('Predicted percept')
plt.show()

##############################################################################
# Increasing the frequency will make phosphenes brighter:

fig, axes = plt.subplots(1, 2, sharex=True, sharey=True)
implant.stim = {'A4': BiphasicPulseTrain(50, 1, 0.45)}
new_percept = model.predict_percept(implant)
new_percept.plot(ax=axes[1])