def test_TemporalModel():
    """Exercise a minimal TemporalModel subclass (``ValidTemporalModel``).

    Covers building, overriding ``dt``, attribute freezing, the shape of
    predicted percepts, automatic sorting of ``t_percept``, and invalid calls.
    """
    # A fresh model only reports itself as built after build():
    tmodel = ValidTemporalModel()
    npt.assert_equal(tmodel.is_built, False)
    tmodel.build()
    npt.assert_equal(tmodel.is_built, True)

    # dt can be given to the constructor and overridden again in build():
    tmodel = ValidTemporalModel(dt=2e-5)
    npt.assert_almost_equal(tmodel.dt, 2e-5)
    tmodel.build(dt=1.234)
    npt.assert_almost_equal(tmodel.dt, 1.234)

    # Unknown parameters are rejected at construction time and afterwards:
    with pytest.raises(AttributeError):
        ValidTemporalModel(newparam=1)
    with pytest.raises(FreezeError):
        tmodel.newparam = 1

    # An implant without a stimulus yields no percept:
    npt.assert_equal(tmodel.predict_percept(ArgusI().stim), None)

    tmodel.dt = 1
    sources = (np.ones((16, 3)), np.zeros((16, 3)), {'A1': [1, 2]},
               np.ones((16, 2)))
    for source in sources:
        argus = ArgusI(stim=source)
        percept = tmodel.predict_percept(argus.stim)
        # By default, percept is output every 20ms. If stimulus is too short,
        # output at t=[0, 20]. This is mentioned in the docs - for really short
        # stimuli, users should specify the desired time points manually.
        frames = 2 if argus.stim.time is not None else 1
        npt.assert_equal(percept.shape, (argus.stim.shape[0], 1, frames))
        npt.assert_almost_equal(percept.data, 0)

    # Requested time points come back sorted:
    tmodel.dt = 0.1
    percept = tmodel.predict_percept(Stimulus(np.zeros((3, 17))),
                                     t_percept=[0.1, 0.8, 0.6])
    npt.assert_almost_equal(percept.time, [0.1, 0.6, 0.8])

    # Invalid calls:
    with pytest.raises(ValueError):
        # Requested t_percept values must be multiples of dt:
        tmodel.predict_percept(Stimulus(np.ones((3, 9))),
                               t_percept=[0.1, 0.11])
    with pytest.raises(ValueError):
        # Temporal model given a stimulus whose time is None.
        # NOTE(review): this model is never built, and the NotBuiltError check
        # below expects this exact call to raise NotBuiltError, so this
        # assertion may be satisfied by the not-built check rather than by a
        # stim.time check -- confirm, and consider calling build() first.
        ValidTemporalModel().predict_percept(Stimulus(3))
    with pytest.raises(ValueError):
        # stim.time is None, yet t_percept is requested:
        ValidTemporalModel().predict_percept(Stimulus(3), t_percept=[0, 1, 2])
    with pytest.raises(NotBuiltError):
        # build() has to be called first:
        ValidTemporalModel().predict_percept(Stimulus(3))
    with pytest.raises(TypeError):
        # The argument must be a stimulus, not an implant:
        ValidTemporalModel().build().predict_percept(ArgusI())
def test_PulseTrain():
    """PulseTrain: tiling a pulse, electrode naming, trimming, bad input.

    NOTE(review): another ``test_PulseTrain`` is defined later in this module.
    The later definition replaces this one at import time, so pytest never
    collects this version. Rename or remove one of them.
    """
    # An all-zero pulse tiles into an all-zero train:
    npt.assert_almost_equal(PulseTrain(10, Stimulus(np.zeros((1, 5)))).data, 0)

    # Toy pulse with exactly one negative sample:
    toy = Stimulus([[0, -1, 0]], time=[0, 0.1, 0.2])
    for count in (2, 3, 10):
        train = PulseTrain(10, toy, n_pulses=count, electrode='A4')
        # One -1 sample per requested pulse:
        npt.assert_equal(np.sum(np.isclose(train.data, -1)), count)
        npt.assert_equal(train.electrodes, 'A4')

    # The train ends exactly at stim_dur, trimming a pulse if necessary:
    train = PulseTrain(3, toy, stim_dur=11)
    npt.assert_almost_equal(train.time[-1], 11)
    npt.assert_almost_equal(train[0, 11], 0)

    # Invalid calls:
    with pytest.raises(TypeError):
        # Not a Stimulus:
        PulseTrain(10, {'a': 1})
    with pytest.raises(ValueError):
        # The pulse does not fit into one period:
        PulseTrain(100000, toy)
    with pytest.raises(ValueError):
        # Too many pulses requested:
        PulseTrain(10, toy, n_pulses=100000)
    with pytest.raises(ValueError):
        # Stimulus without a time axis:
        PulseTrain(10, Stimulus(1))
    with pytest.raises(ValueError):
        # Empty stimulus after compression:
        toy = Stimulus([[0, 0, 0]], time=[0, 0.1, 0.2], compress=True)
        PulseTrain(10, toy)
def test_Model_predict_percept():
    """Model.predict_percept with no, spatial-only, temporal-only, and both
    sub-models, followed by invalid calls on the combined model."""
    # Without sub-models there is nothing to build and nothing to perceive:
    empty = Model()
    npt.assert_equal(empty.predict_percept(ArgusI()), None)
    npt.assert_equal(empty.predict_percept(ArgusI(stim={'A1': 1})), None)
    npt.assert_equal(
        empty.predict_percept(ArgusI(stim={'A1': 1}), t_percept=[0, 1]), None)

    # Spatial sub-model only, implant without stimulus:
    spatial_only = Model(spatial=ValidSpatialModel()).build()
    npt.assert_equal(spatial_only.predict_percept(ArgusI()), None)

    # Temporal sub-model only, implant without stimulus:
    temporal_only = Model(temporal=ValidTemporalModel()).build()
    npt.assert_equal(temporal_only.predict_percept(ArgusI()), None)

    # Both sub-models -- invalid calls:
    combined = Model(spatial=ValidSpatialModel(),
                     temporal=ValidTemporalModel())
    with pytest.raises(NotBuiltError):
        # build() has to be called first:
        combined.predict_percept(ArgusI())
    combined.build()
    with pytest.raises(ValueError):
        # Requested t_percept values must be multiples of dt:
        combined.predict_percept(ArgusI(stim={'A1': np.ones(16)}),
                                 t_percept=[0.1, 0.11])
    with pytest.raises(ValueError):
        # Temporal model given a stimulus whose time is None.
        # NOTE(review): this calls a standalone, never-built
        # ValidTemporalModel, not the combined Model under test -- confirm
        # this is intended.
        ValidTemporalModel().predict_percept(Stimulus(3))
    with pytest.raises(ValueError):
        # stim.time is None, yet t_percept is requested:
        combined.predict_percept(ArgusI(stim=np.ones(16)),
                                 t_percept=[0, 1, 2])
    with pytest.raises(TypeError):
        # The argument must be an implant, not a stimulus:
        combined.predict_percept(Stimulus(3))
def test_Stimulus_remove():
    """Stimulus.remove drops the named electrode and leaves the rest."""
    stim = Stimulus([[0, 1, 2], [3, 4, 5]], electrodes=['A1', 'C3'])
    for name in ('A1', 'C3'):
        npt.assert_equal(name in stim.electrodes, True)
    stim.remove('A1')
    npt.assert_equal('A1' in stim.electrodes, False)
    npt.assert_equal('C3' in stim.electrodes, True)
def test_PulseTrain():
    """PulseTrain: tiling a pulse, a too-short stim_dur, and bad input."""
    # An all-zero pulse tiles into an all-zero train:
    npt.assert_almost_equal(PulseTrain(10, Stimulus(np.zeros((1, 5)))).data, 0)

    # Toy pulse with exactly one negative sample:
    toy = Stimulus([[0, -1, 0]], time=[0, 0.1, 0.2])
    for count in (2, 3, 10):
        train = PulseTrain(10, toy, n_pulses=count)
        # One -1 sample per requested pulse:
        npt.assert_equal(np.sum(np.isclose(train.data, -1)), count)

    # stim_dur too short: the resulting train is all zeros.
    npt.assert_almost_equal(PulseTrain(2, toy, stim_dur=10).data, 0)

    # Invalid calls:
    with pytest.raises(TypeError):
        # Not a Stimulus:
        PulseTrain(10, {'a': 1})
    with pytest.raises(ValueError):
        # The pulse does not fit into one period:
        PulseTrain(100000, toy)
    with pytest.raises(ValueError):
        # Too many pulses requested:
        PulseTrain(10, toy, n_pulses=100000)
    with pytest.raises(ValueError):
        # Stimulus without a time axis:
        PulseTrain(10, Stimulus(1))
    with pytest.raises(ValueError):
        # Empty stimulus after compression:
        toy = Stimulus([[0, 0, 0]], time=[0, 0.1, 0.2], compress=True)
        PulseTrain(10, toy)
def test_MonophasicPulse(amp, delay_dur):
    """MonophasicPulse: timing, polarity, charge balance, wrapping into a
    Stimulus, and invalid arguments.

    ``amp`` and ``delay_dur`` are supplied by test parametrization (the
    decorators are not part of this block).
    """
    phase_dur = 3.456
    mid_phase = delay_dur + phase_dur / 2.0
    min_dur = phase_dur + delay_dur

    # Default duration: the time axis ends when the single phase ends.
    mono = MonophasicPulse(amp, phase_dur, delay_dur=delay_dur)
    npt.assert_almost_equal(mono[0, 0], 0)
    npt.assert_almost_equal(mono[0, mid_phase], amp)
    npt.assert_almost_equal(mono.time[0], 0)
    npt.assert_almost_equal(mono.time[-1], min_dur)
    npt.assert_equal(mono.cathodic, amp <= 0)
    npt.assert_equal(mono.charge_balanced, False)

    # An explicit stim_dur extends the time axis:
    mono = MonophasicPulse(amp, phase_dur, delay_dur=delay_dur, stim_dur=100)
    npt.assert_almost_equal(mono[0, 0], 0)
    npt.assert_almost_equal(mono[0, mid_phase], amp)
    npt.assert_almost_equal(mono.time[0], 0)
    npt.assert_almost_equal(mono.time[-1], 100)

    # Zero amplitude: all zeros, and reported as charge-balanced.
    mono = MonophasicPulse(0, phase_dur, delay_dur=delay_dur, electrode='A1')
    npt.assert_almost_equal(mono.data, 0)
    npt.assert_almost_equal(mono.time[0], 0)
    npt.assert_almost_equal(mono.time[-1], min_dur)
    npt.assert_equal(mono.charge_balanced, True)
    npt.assert_equal(mono.electrodes, 'A1')

    # Wrapping in a Stimulus can overwrite attributes such as the name:
    wrapped = Stimulus(mono, electrodes='AA1')
    npt.assert_equal(wrapped.electrodes, 'AA1')

    # A list of pulses is stacked row-wise; expected names are ['A1', 1]:
    stacked = Stimulus([mono, mono])
    npt.assert_equal(stacked.shape[0], 2)
    npt.assert_almost_equal(stacked.data[0, :], stacked.data[1, :])
    npt.assert_almost_equal(stacked.time, mono.time)
    npt.assert_equal(stacked.electrodes, ['A1', 1])

    # ...and the stacked rows can be renamed:
    stacked = Stimulus([mono, mono], electrodes=['C1', 'D2'])
    npt.assert_equal(stacked.electrodes, ['C1', 'D2'])

    # Invalid calls:
    with pytest.raises(ValueError):
        # Zero phase duration:
        MonophasicPulse(amp, 0)
    with pytest.raises(ValueError):
        # Negative delay:
        MonophasicPulse(amp, phase_dur, delay_dur=-1)
    with pytest.raises(ValueError):
        # stim_dur too short to hold the pulse:
        MonophasicPulse(amp, phase_dur, delay_dur=delay_dur, stim_dur=1)
    with pytest.raises(ValueError):
        # More than one electrode:
        MonophasicPulse(amp, phase_dur, delay_dur=delay_dur,
                        electrode=['A1', 'B2'])
def test_Stimulus__stim():
    """The ``_stim`` setter validates the data/electrodes/time container.

    A user could try to modify the data container after the constructor,
    which would lead to inconsistencies between data, electrodes, and time.
    The property-setting mechanism prevents that; this test checks each
    validation rule, then shows that a consistent container is accepted.

    Each dict mutation is done before its ``pytest.raises`` block, so only
    the ``_stim`` assignment itself can satisfy the expected exception.
    """
    stim = Stimulus(3)

    # Requires dict:
    with pytest.raises(AttributeError):
        stim._stim = np.array([0, 1])

    # Dict must have all required fields:
    fields = ['data', 'electrodes', 'time']
    for field in fields:
        partial = [f for f in fields if f != field]
        with pytest.raises(AttributeError):
            stim._stim = {f: None for f in partial}

    # Data must be a 2-D NumPy array:
    data = {f: None for f in fields}
    data['data'] = np.ones(3)
    with pytest.raises(ValueError):
        stim._stim = data

    # Data rows must match electrodes:
    data['data'] = np.ones((3, 4))
    data['time'] = np.arange(4)
    data['electrodes'] = np.arange(2)
    with pytest.raises(ValueError):
        stim._stim = data

    # Data columns must match time:
    data['data'] = np.ones((3, 4))
    data['electrodes'] = np.arange(3)
    data['time'] = np.arange(7)
    with pytest.raises(ValueError):
        stim._stim = data

    # Time points must be unique (helper defined elsewhere in this module):
    assert_warns_msg(UserWarning, _unique_timepoints, None, stim, data)

    # But if you do all the things right, you can reset the stimulus by hand:
    data['data'] = np.ones((3, 1))
    data['electrodes'] = np.arange(3)
    data['time'] = None
    stim._stim = data

    data['data'] = np.ones((3, 1))
    data['electrodes'] = np.arange(3)
    data['time'] = np.arange(1)
    stim._stim = data

    data['data'] = np.ones((3, 4))
    data['electrodes'] = np.arange(3)
    data['time'] = np.array([0, 1, 1 + DT, 2])
    stim._stim = data
def test_ProsthesisSystem():
    """ProsthesisSystem: construction, indexing, the stim setter, ``__slots__``.

    NOTE(review): another ``test_ProsthesisSystem`` is defined later in this
    module. The later definition replaces this one at import time, so pytest
    never collects this version (which also relies on ``keys()``, whereas the
    later one uses ``electrode_names``). Rename or remove one of them.
    """
    # An invalid eye specification is rejected:
    with pytest.raises(ValueError):
        ProsthesisSystem(ElectrodeArray(PointSource(0, 0, 0)), eye='both')

    # The system forwards indexing and keys to its electrode array:
    system = ProsthesisSystem(PointSource(0, 0, 0))
    npt.assert_equal(system.n_electrodes, 1)
    npt.assert_equal(system[0], system.earray[0])
    npt.assert_equal(system.keys(), system.earray.keys())

    # A scalar stimulus assigned after construction becomes a 1x1 Stimulus:
    npt.assert_equal(system.stim, None)
    system.stim = 3
    npt.assert_equal(isinstance(system.stim, Stimulus), True)
    npt.assert_equal(system.stim.shape, (1, 1))
    npt.assert_equal(system.stim.time, None)
    npt.assert_equal(system.stim.electrodes, [0])

    with pytest.raises(ValueError):
        # More stimuli than electrodes:
        system.stim = [1, 2]
    with pytest.raises(TypeError):
        # Unsupported stimulus type:
        system.stim = "stim"
    # Electrode names that do not exist on the implant:
    with pytest.raises(ValueError):
        system.stim = {'A1': 1}
    with pytest.raises(ValueError):
        system.stim = Stimulus({'A1': 1})

    # Instances use __slots__ instead of a per-instance __dict__:
    npt.assert_equal(hasattr(system, '__slots__'), True)
    npt.assert_equal(hasattr(system, '__dict__'), False)
def test_predict_batched(engine):
    """predict_percept_batched must agree with serial predict_percept calls.

    ``engine`` is supplied by test parametrization (decorator not shown). The
    whole test is skipped when JAX is not installed. For engines other than
    'jax', batched prediction is expected to raise ImportError.
    """
    if not has_jax:
        pytest.skip("Jax not installed")

    # Allows mix of valid Stimulus types (dicts of pulse trains and a
    # ready-made Stimulus):
    stims = [{
        'A5': BiphasicPulseTrain(25, 4, 0.45),
        'C7': BiphasicPulseTrain(50, 2.5, 0.75)
    }, {
        'B4': BiphasicPulseTrain(3, 1, 0.32)
    }, Stimulus({'F3': BiphasicPulseTrain(12, 3, 1.2)})]
    implant = ArgusII()
    model = BiphasicAxonMapModel(engine=engine, xystep=2)
    model.build()

    # JAX is installed at this point (see the skip above). Batched prediction
    # is only expected to work with the 'jax' engine; every other engine is
    # expected to raise ImportError.
    if engine != 'jax':
        with pytest.raises(ImportError):
            model.predict_percept_batched(implant, stims)
        return

    percepts_batched = model.predict_percept_batched(implant, stims)

    # Reference: predict each stimulus one at a time on the same implant.
    percepts_serial = []
    for stim in stims:
        implant.stim = stim
        percepts_serial.append(model.predict_percept(implant))

    npt.assert_equal(len(percepts_serial), len(percepts_batched))
    for batched, serial in zip(percepts_batched, percepts_serial):
        npt.assert_almost_equal(batched.data, serial.data)
def test_Stimulus_merge():
    """Stacking stimuli merges (unions) their time axes, even when nested."""
    ends = [0, -1]

    first = Stimulus([[0, 1, 2, 3, 4]], time=[0, 1, 2, 3, 4])
    second = Stimulus([[0, 1, 2]], time=[-0.5, 1.5, 4.5])
    merged = Stimulus([first, second])
    npt.assert_almost_equal(
        merged.time, np.unique(np.hstack((first.time, second.time))),
        decimal=6)
    npt.assert_almost_equal(merged[0, ends], first[0, ends])
    npt.assert_almost_equal(merged[1, ends], second[0, ends])

    # Stacking an already-merged stimulus keeps working:
    third = Stimulus([[14]], time=[9.7])
    nested = Stimulus([merged, third])
    npt.assert_almost_equal(
        nested.time,
        np.unique(np.hstack((first.time, second.time, third.time))),
        decimal=6)
    for row, source in enumerate((first, second, third)):
        npt.assert_almost_equal(nested[row, ends], source[0, ends])
def test_merge_time_axes_merge_tolerance():
    """Regression test: merging time axes must not corrupt stimulus data.

    Previously not enough unique time points were collected, so
    interpolation produced corrupted samples.
    See: https://github.com/pulse2percept/pulse2percept/issues/392
    """
    train_20 = BiphasicPulseTrain(20, 1, 0.45)
    train_30 = BiphasicPulseTrain(30, 1, 0.45)
    values = np.unique(Stimulus({"A2": train_20, "A10": train_30}).data)
    # A corrupted sample would sit close to +1/3 or -1/3:
    for corrupted in (1 / 3, -1 / 3):
        npt.assert_equal(np.isclose(corrupted, values, atol=0.1).any(), False)
def test_Stimulus_arithmetic(scalar):
    """Stimulus supports arithmetic with scalars and time shifts via >>/<<.

    ``scalar`` is supplied by test parametrization (decorator not shown).
    Stimulus-with-Stimulus and Stimulus-with-array operations are rejected.
    """
    stim = Stimulus([[0, 21, -13, 0, 0]], time=[0, 1, 2, 3, 4])

    # (operation on the Stimulus, same operation on the raw data):
    data_checks = (
        (lambda: stim + scalar, lambda: stim.data + scalar),
        (lambda: scalar + stim, lambda: scalar + stim.data),
        (lambda: stim - scalar, lambda: stim.data - scalar),
        (lambda: scalar - stim, lambda: scalar - stim.data),
        (lambda: stim * scalar, lambda: stim.data * scalar),
        (lambda: scalar * stim, lambda: scalar * stim.data),
        (lambda: stim / scalar, lambda: stim.data / scalar),
        (lambda: -stim, lambda: -1 * stim.data),
    )
    for operate, reference in data_checks:
        npt.assert_almost_equal(operate().data, reference(), decimal=5)

    # Shifting operators move the time axis:
    npt.assert_almost_equal((stim >> scalar).time, stim.time + scalar,
                            decimal=5)
    npt.assert_almost_equal((stim << scalar).time, stim.time - scalar,
                            decimal=5)

    # scalar / stim is not supported because it will always divide by zero;
    # neither are operations between two stimuli or with sequences/arrays:
    unsupported = (
        lambda: scalar / stim,
        lambda: stim + stim,
        lambda: stim - stim,
        lambda: stim * stim,
        lambda: stim / stim,
        lambda: stim + [1, 1],
        lambda: stim * np.array([2, 3]),
        lambda: stim >> np.array([2, 3]),
        lambda: stim << np.array([2, 3]),
    )
    for operate in unsupported:
        with pytest.raises(TypeError):
            operate()
def test_ProsthesisSystem():
    """ProsthesisSystem: construction, iteration, the stim setter, plotting,
    safe mode, and ``__slots__``."""
    # Invalid instantiations:
    with pytest.raises(ValueError):
        ProsthesisSystem(ElectrodeArray(PointSource(0, 0, 0)), eye='both')
    with pytest.raises(TypeError):
        ProsthesisSystem(Stimulus)

    # Iterating over the electrode array:
    implant = ProsthesisSystem(PointSource(0, 0, 0))
    npt.assert_equal(implant.n_electrodes, 1)
    npt.assert_equal(implant[0], implant.earray[0])
    npt.assert_equal(implant.electrode_names, implant.earray.electrode_names)
    for via_implant, via_earray in zip(implant, implant.earray):
        npt.assert_equal(via_implant, via_earray)

    # Set a stimulus after the constructor:
    npt.assert_equal(implant.stim, None)
    implant.stim = 3
    npt.assert_equal(isinstance(implant.stim, Stimulus), True)
    npt.assert_equal(implant.stim.shape, (1, 1))
    npt.assert_equal(implant.stim.time, None)
    npt.assert_equal(implant.stim.electrodes, [0])

    ax = implant.plot()
    npt.assert_equal(len(ax.texts), 0)
    npt.assert_equal(len(ax.collections), 1)

    with pytest.raises(ValueError):
        # Wrong number of stimuli
        implant.stim = [1, 2]
    with pytest.raises(TypeError):
        # Invalid stim type:
        implant.stim = "stim"
    # Invalid electrode names:
    with pytest.raises(ValueError):
        implant.stim = {'A1': 1}
    with pytest.raises(ValueError):
        implant.stim = Stimulus({'A1': 1})

    # Safe mode requires charge-balanced pulses. The implant is built outside
    # the raises-block so that only the stimulus assignment can satisfy it:
    implant = ProsthesisSystem(PointSource(0, 0, 0), safe_mode=True)
    with pytest.raises(ValueError):
        implant.stim = 1

    # Slots:
    npt.assert_equal(hasattr(implant, '__slots__'), True)
    npt.assert_equal(hasattr(implant, '__dict__'), False)
def test_SpatialModel():
    """SpatialModel: grid building, parameter overrides, attribute freezing,
    percept shape, and invalid calls on ``ValidSpatialModel``."""
    # Build grid:
    model = ValidSpatialModel()
    npt.assert_equal(model.grid, None)
    npt.assert_equal(model.is_built, False)
    model.build()
    npt.assert_equal(model.is_built, True)
    npt.assert_equal(isinstance(model.grid, GridXY), True)
    npt.assert_equal(isinstance(model.grid.xret, np.ndarray), True)

    # Can overwrite default values:
    model = ValidSpatialModel(xystep=1.234)
    npt.assert_almost_equal(model.xystep, 1.234)
    model.build(xystep=2.345)
    npt.assert_almost_equal(model.xystep, 2.345)

    # Cannot add more attributes:
    with pytest.raises(AttributeError):
        ValidSpatialModel(newparam=1)
    with pytest.raises(FreezeError):
        model.newparam = 1

    # Returns Percept object of proper size:
    npt.assert_equal(model.predict_percept(ArgusI()), None)
    for stim in [np.ones(16), np.zeros(16), {'A1': 2}, np.ones((16, 2))]:
        implant = ArgusI(stim=stim)
        percept = model.predict_percept(implant)
        npt.assert_equal(isinstance(percept, Percept), True)
        n_time = 1 if implant.stim.time is None else len(implant.stim.time)
        npt.assert_equal(
            percept.shape,
            (model.grid.x.shape[0], model.grid.x.shape[1], n_time))
        npt.assert_almost_equal(percept.data, 0)

    # Invalid calls. Setup is done before each raises-block so that only the
    # call under test can satisfy the expected exception.
    # stim.time==None but requesting t_percept != None:
    implant.stim = np.ones(16)
    with pytest.raises(ValueError):
        model.predict_percept(implant, t_percept=[0, 1, 2])
    # Must call build first:
    unbuilt = ValidSpatialModel()
    with pytest.raises(NotBuiltError):
        unbuilt.predict_percept(ArgusI())
    # Must pass an implant:
    built = ValidSpatialModel().build()
    with pytest.raises(TypeError):
        built.predict_percept(Stimulus(3))
def test_ProsthesisSystem_stim():
    """ProsthesisSystem.stim: size check and (de)activated electrodes.

    NOTE(review): another ``test_ProsthesisSystem_stim`` is defined later in
    this module. The later definition replaces this one at import time, so
    pytest never collects this version. Rename or remove one of them.
    """
    implant = ProsthesisSystem(ElectrodeGrid((13, 13), 20))
    oversized = Stimulus(np.ones((13 * 13 + 1, 5)))
    with pytest.raises(ValueError):
        # One more stimulus row than the 13x13 grid has electrodes:
        implant.stim = oversized

    # A deactivated electrode does not receive its stimulus:
    implant.deactivate('H4')
    npt.assert_equal(implant['H4'].activated, False)
    implant.stim = {'H4': 1}
    npt.assert_equal('H4' in implant.stim.electrodes, False)

    # With every electrode deactivated, the stimulus data is empty (falsy):
    implant.deactivate('all')
    npt.assert_equal(not implant.stim.data, True)

    # After reactivation the electrode accepts stimuli again:
    implant.activate('all')
    implant.stim = {'H4': 1}
    npt.assert_equal('H4' in implant.stim.electrodes, True)
def test_FadingTemporal():
    """FadingTemporal: parameter handling, freezing, None/zero input,
    single-pulse decay, and invalid arguments."""
    model = FadingTemporal()
    # User can set their own params:
    model.dt = 0.1
    npt.assert_equal(model.dt, 0.1)
    model.build(dt=1e-4)
    npt.assert_equal(model.dt, 1e-4)
    # User cannot add more model parameters:
    with pytest.raises(FreezeError):
        model.rho = 100

    # Nothing in, None out:
    npt.assert_equal(model.predict_percept(None), None)

    # Zero in = zero out:
    stim = BiphasicPulse(0, 1)
    percept = model.predict_percept(stim, t_percept=[0, 1, 2])
    npt.assert_equal(isinstance(percept, Percept), True)
    npt.assert_equal(percept.shape, (1, 1, 3))
    npt.assert_almost_equal(percept.data, 0)

    # Can't request the same time more than once (this would break the Cython
    # loop, because `idx_frame` is incremented after a write; also doesn't
    # make much sense). The stimulus is built outside the raises-block so that
    # only predict_percept can satisfy it:
    stim = Stimulus(np.ones((1, 100)))
    with pytest.raises(ValueError):
        model.predict_percept(stim, t_percept=[0.2, 0.2])

    # Simple decay for single cathodic pulse:
    model = FadingTemporal(tau=1).build()
    stim = MonophasicPulse(-1, 1, stim_dur=10)
    percept = model.predict_percept(stim, np.arange(stim.duration))
    npt.assert_almost_equal(percept.data.ravel()[:3], [0, 0.633, 0.232],
                            decimal=3)
    npt.assert_almost_equal(percept.data.ravel()[-1], 0, decimal=3)

    # But all zeros for anodic pulse:
    stim = MonophasicPulse(1, 1, stim_dur=10)
    percept = model.predict_percept(stim, np.arange(stim.duration))
    npt.assert_almost_equal(percept.data, 0)

    # tau cannot be negative:
    with pytest.raises(ValueError):
        FadingTemporal(tau=-1).build()
def test_metadata():
    """Pulse trains record their parameters and user data in ``metadata``,
    and a Stimulus built from them records per-electrode metadata."""
    train = BiphasicPulseTrain(10, 10, 1, metadata='userdata')
    expected = {'user': 'userdata', 'freq': 10, 'amp': 10, 'phase_dur': 1,
                'delay_dur': 0}
    for key, value in expected.items():
        npt.assert_equal(train.metadata[key], value)

    per_electrode = {
        'A2': BiphasicPulseTrain(10, 10, 1, metadata='userdataA2'),
        'B1': BiphasicPulseTrain(11, 9, 2, metadata='userdataB1'),
        'C3': BiphasicPulseTrain(12, 8, 3, metadata='userdataC3'),
    }
    stim = Stimulus(per_electrode, metadata='stimulus_userdata')
    npt.assert_equal(stim.metadata['user'], 'stimulus_userdata')
    electrode_meta = stim.metadata['electrodes']
    npt.assert_equal(electrode_meta['A2']['type'], BiphasicPulseTrain)
    npt.assert_equal(electrode_meta['B1']['metadata']['freq'], 11)
    npt.assert_equal(electrode_meta['C3']['metadata']['user'], 'userdataC3')
def test_ProsthesisSystem_stim():
    """ProsthesisSystem.stim: size check, stimulus colormap in ``plot``,
    and deactivated electrodes.

    Colormap RGBA values are compared with a tolerance rather than exact
    float equality, so tiny floating-point differences in the colormap
    lookup do not break the test.
    """
    implant = ProsthesisSystem(ElectrodeGrid((13, 13), 20))
    stim = Stimulus(np.ones((13 * 13 + 1, 5)))
    with pytest.raises(ValueError):
        implant.stim = stim

    # color mapping
    stim = np.zeros((13 * 13, 5))
    stim[84, 0] = 1
    stim[98, 2] = 2
    implant.stim = stim
    plt.cla()
    ax = implant.plot(stim_cmap='hsv')
    plt.colorbar()
    npt.assert_equal(len(ax.collections), 1)
    npt.assert_equal(ax.collections[0].colorbar.vmax, 2)
    npt.assert_almost_equal(ax.collections[0].cmap(ax.collections[0].norm(1)),
                            (0.0, 1.0, 0.9647031631761764, 1), decimal=6)
    # make sure default behaviour unchanged
    plt.cla()
    ax = implant.plot()
    plt.colorbar()
    npt.assert_equal(len(ax.collections), 1)
    npt.assert_equal(ax.collections[0].colorbar.vmax, 1)
    npt.assert_almost_equal(ax.collections[0].cmap(ax.collections[0].norm(1)),
                            (0.993248, 0.906157, 0.143936, 1), decimal=6)

    # Deactivated electrodes cannot receive stimuli:
    implant.deactivate('H4')
    npt.assert_equal(implant['H4'].activated, False)
    implant.stim = {'H4': 1}
    npt.assert_equal('H4' in implant.stim.electrodes, False)

    implant.deactivate('all')
    npt.assert_equal(not implant.stim.data, True)
    implant.activate('all')
    implant.stim = {'H4': 1}
    npt.assert_equal('H4' in implant.stim.electrodes, True)
def test_biphasicAxonMapSpatial(engine):
    """BiphasicAxonMapSpatial / BiphasicAxonMapModel behaviour.

    ``engine`` is supplied by test parametrization (decorator not shown).
    With the 'jax' engine, predict_percept is expected to raise
    NotImplementedError; every other engine runs the full set of checks.
    """
    # Lambda cannot be too small:
    with pytest.raises(ValueError):
        BiphasicAxonMapSpatial(axlambda=9).build()

    model = BiphasicAxonMapModel(engine=engine, xystep=2).build()
    # Jax not implemented yet. The implant is set up outside the raises-block
    # so that only predict_percept can satisfy the expected exception:
    if engine == 'jax':
        implant = ArgusII()
        implant.stim = Stimulus({'A5': BiphasicPulseTrain(20, 1, 0.45)})
        with pytest.raises(NotImplementedError):
            model.predict_percept(implant)
        return

    # Only accepts biphasic pulse trains with no delay dur; a plain array
    # stimulus is rejected:
    implant = ArgusI(stim=np.ones(16))
    with pytest.raises(TypeError):
        model.predict_percept(implant)

    # Nothing in, None out:
    npt.assert_equal(model.predict_percept(ArgusI()), None)

    # Zero in = zero out:
    implant = ArgusI(stim=np.zeros(16))
    percept = model.predict_percept(implant)
    npt.assert_equal(isinstance(percept, Percept), True)
    npt.assert_equal(percept.shape, list(model.grid.x.shape) + [1])
    npt.assert_almost_equal(percept.data, 0)
    npt.assert_equal(percept.time, None)

    # Should be equal to axon map model if effects models return 1
    model = BiphasicAxonMapSpatial(engine=engine, xystep=2)

    def bright_model(freq, amp, pdur):
        return 1

    def size_model(freq, amp, pdur):
        return 1

    def streak_model(freq, amp, pdur):
        return 1

    model.bright_model = bright_model
    model.size_model = size_model
    model.streak_model = streak_model
    model.build()
    axon_map = AxonMapSpatial(xystep=2).build()
    implant = ArgusII()
    implant.stim = Stimulus({'A5': BiphasicPulseTrain(20, 1, 0.45)})
    percept = model.predict_percept(implant)
    percept_axon = axon_map.predict_percept(implant)
    npt.assert_almost_equal(percept.data[:, :, 0],
                            percept_axon.get_brightest_frame())

    # Effect models must be callable
    model = BiphasicAxonMapSpatial(engine=engine, xystep=2)
    model.bright_model = 1.0
    with pytest.raises(TypeError):
        model.build()

    # If t_percept is not specified, there should only be one frame
    model = BiphasicAxonMapSpatial(engine=engine, xystep=2)
    model.build()
    implant = ArgusII()
    implant.stim = Stimulus({'A5': BiphasicPulseTrain(20, 1, 0.45)})
    percept = model.predict_percept(implant)
    npt.assert_equal(percept.time is None, True)

    # If t_percept is specified, only first frame should have data
    # and the rest should be empty
    percept = model.predict_percept(implant, t_percept=[0, 1, 2, 5, 10])
    npt.assert_equal(len(percept.time), 5)
    npt.assert_equal(np.any(percept.data[:, :, 0]), True)
    npt.assert_equal(np.any(percept.data[:, :, 1:]), False)

    # Test that default models give expected values
    model = BiphasicAxonMapSpatial(engine=engine, rho=400, axlambda=600,
                                   xystep=1, xrange=(-20, 20),
                                   yrange=(-15, 15))
    model.build()
    implant = ArgusII()
    implant.stim = Stimulus({'A4': BiphasicPulseTrain(20, 1, 1)})
    percept = model.predict_percept(implant)
    # (brightness threshold, expected number of pixels above it):
    for threshold, n_above in ((0.0813, 81), (0.1626, 59), (0.2439, 44),
                               (0.4065, 26), (0.5691, 14)):
        npt.assert_equal(np.sum(percept.data > threshold), n_above)
pt.plot() ############################################################################### # Generic pulse trains # -------------------- # # Finally, you can concatenate any :py:class:`~pulse2percept.stimuli.Stimulus` # object into a pulse train. # # For example, let's define a single ramp stimulus: import numpy as np from pulse2percept.stimuli import Stimulus, PulseTrain # Single ramp: dt = 1e-3 ramp = Stimulus([[0, 0, 1, 1, 2, 2, 0, 0]], time=[0, 1, 1 + dt, 2, 2 + dt, 3, 3 + dt, 5 - dt]) ramp.plot() # Ramp train: PulseTrain(20, ramp, stim_dur=200).plot() # Biphasic ramp: biphasic_ramp = Stimulus(np.concatenate((ramp.data, -ramp.data), axis=1), time=np.concatenate((ramp.time, ramp.time + 5))) biphasic_ramp.plot() # Biphasic ramp train: PulseTrain(20, biphasic_ramp, stim_dur=200).plot()
delay_dur = 12 # Vary this current to determine threshold: amp_test = 45 # Cathodic phase of the test pulse (delivered after a delay): cath_test = MonophasicPulse(-amp_test, phase_dur, delay_dur=delay_dur) ############################################################################### # The anodic phase were always presented 20 ms after the second cathodic phase: anod_standard = MonophasicPulse(0.5 * amp_th, phase_dur, delay_dur=20) anod_test = MonophasicPulse(amp_test, phase_dur, delay_dur=delay_dur) ############################################################################### # The last step is to concatenate all the pulses into a single stimulus: from pulse2percept.stimuli import Stimulus data = [] time = [] time_tracker = 0 for pulse in (cath_standard, cath_test, anod_standard, anod_test): data.append(pulse.data) time.append(pulse.time + time_tracker) time_tracker += pulse.time[-1] latent_add = Stimulus(np.concatenate(data, axis=1), time=np.concatenate(time)) latent_add.plot()
def test_AsymmetricBiphasicPulse(amp1, amp2, interphase_dur, delay_dur,
                                 cathodic_first):
    """AsymmetricBiphasicPulse: phase timing/amplitudes, stim_dur handling,
    charge balance, equivalence to symmetric and monophasic pulses, wrapping
    into a Stimulus, and invalid arguments.

    All arguments are supplied by test parametrization (decorators not
    shown). Charge balance is checked against a trapezoidal integral of the
    pulse. ``np.trapz`` was deprecated in NumPy 2.0 in favor of
    ``np.trapezoid``, so the newer name is used when available.
    """
    trapezoid = getattr(np, 'trapezoid', None)
    if trapezoid is None:  # NumPy < 2.0
        trapezoid = np.trapz

    phase_dur1 = 2.1
    phase_dur2 = 4.87
    mid_first_pulse = delay_dur + phase_dur1 / 2.0
    mid_interphase = delay_dur + phase_dur1 + interphase_dur / 2.0
    mid_second_pulse = delay_dur + phase_dur1 + interphase_dur + phase_dur2 / 2
    first_amp = -np.abs(amp1) if cathodic_first else np.abs(amp1)
    second_amp = np.abs(amp2) if cathodic_first else -np.abs(amp2)
    min_dur = delay_dur + phase_dur1 + interphase_dur + phase_dur2

    # Basic usage:
    pulse = AsymmetricBiphasicPulse(amp1, amp2, phase_dur1, phase_dur2,
                                    interphase_dur=interphase_dur,
                                    delay_dur=delay_dur,
                                    cathodic_first=cathodic_first)
    npt.assert_almost_equal(pulse[0, 0], 0)
    npt.assert_almost_equal(pulse[0, mid_first_pulse], first_amp)
    npt.assert_almost_equal(pulse[0, mid_interphase], 0)
    npt.assert_almost_equal(pulse[0, mid_second_pulse], second_amp)
    npt.assert_almost_equal(pulse.time[0], 0)
    npt.assert_almost_equal(pulse.time[-1], min_dur, decimal=3)
    npt.assert_equal(pulse.cathodic_first, cathodic_first)
    npt.assert_equal(pulse.is_charge_balanced,
                     np.isclose(trapezoid(pulse.data, pulse.time)[0], 0))

    # Custom stim dur:
    pulse = AsymmetricBiphasicPulse(amp1, amp2, phase_dur1, phase_dur2,
                                    interphase_dur=interphase_dur,
                                    delay_dur=delay_dur,
                                    cathodic_first=cathodic_first,
                                    stim_dur=100, electrode='A1')
    npt.assert_almost_equal(pulse[0, 0], 0)
    npt.assert_almost_equal(pulse[0, mid_first_pulse], first_amp)
    npt.assert_almost_equal(pulse[0, mid_interphase], 0)
    npt.assert_almost_equal(pulse[0, mid_second_pulse], second_amp)
    npt.assert_almost_equal(pulse.time[0], 0)
    npt.assert_almost_equal(pulse.time[-1], 100)
    npt.assert_equal(pulse.electrodes, 'A1')

    # Exact stim dur:
    stim_dur = delay_dur + phase_dur1 + interphase_dur + phase_dur2
    pulse = AsymmetricBiphasicPulse(amp1, amp2, phase_dur1, phase_dur2,
                                    interphase_dur=interphase_dur,
                                    delay_dur=delay_dur,
                                    cathodic_first=cathodic_first,
                                    stim_dur=stim_dur, electrode='A1')
    npt.assert_almost_equal(pulse.time[0], 0)
    npt.assert_almost_equal(pulse.time[-1], stim_dur, decimal=6)

    # Zero amplitude:
    pulse = AsymmetricBiphasicPulse(0, 0, phase_dur1, phase_dur2,
                                    interphase_dur=interphase_dur,
                                    delay_dur=delay_dur,
                                    cathodic_first=cathodic_first)
    npt.assert_almost_equal(pulse.data, 0)
    npt.assert_almost_equal(pulse.time[0], 0)
    npt.assert_almost_equal(pulse.time[-1], min_dur, decimal=3)
    npt.assert_equal(pulse.is_charge_balanced,
                     np.isclose(trapezoid(pulse.data, pulse.time)[0], 0))

    # If both phases have the same values, it's basically a symmetric biphasic
    # pulse:
    abp = AsymmetricBiphasicPulse(amp1, amp1, phase_dur1, phase_dur1,
                                  interphase_dur=interphase_dur,
                                  delay_dur=delay_dur,
                                  cathodic_first=cathodic_first)
    bp = BiphasicPulse(amp1, phase_dur1, interphase_dur=interphase_dur,
                       delay_dur=delay_dur, cathodic_first=cathodic_first)
    bp_min_dur = phase_dur1 * 2 + interphase_dur + delay_dur
    probe = np.linspace(0, bp_min_dur, num=5)
    npt.assert_almost_equal(abp[:, probe], bp[:, probe])
    npt.assert_equal(abp.cathodic_first, bp.cathodic_first)

    # If one phase is zero, it's basically a monophasic pulse:
    abp = AsymmetricBiphasicPulse(amp1, 0, phase_dur1, phase_dur2,
                                  interphase_dur=interphase_dur,
                                  delay_dur=delay_dur,
                                  cathodic_first=cathodic_first)
    mono = MonophasicPulse(first_amp, phase_dur1, delay_dur=delay_dur,
                           stim_dur=min_dur)
    probe = np.linspace(0, min_dur, num=5)
    npt.assert_almost_equal(abp[:, probe], mono[:, probe])
    npt.assert_equal(abp.cathodic_first, mono.cathodic)

    # You can wrap a pulse in a Stimulus to overwrite attributes:
    stim = Stimulus(pulse, electrodes='AA1')
    npt.assert_equal(stim.electrodes, 'AA1')
    # Or concatenate:
    stim = Stimulus([pulse, pulse])
    npt.assert_equal(stim.shape[0], 2)
    npt.assert_almost_equal(stim.data[0, :], stim.data[1, :])
    npt.assert_almost_equal(stim.time, pulse.time, decimal=2)
    npt.assert_equal(stim.electrodes, [0, 1])
    # Concatenate and rename:
    stim = Stimulus([pulse, pulse], electrodes=['C1', 'D2'])
    npt.assert_equal(stim.electrodes, ['C1', 'D2'])

    # Invalid calls:
    with pytest.raises(ValueError):
        AsymmetricBiphasicPulse(amp1, amp2, 0, phase_dur2)
    with pytest.raises(ValueError):
        AsymmetricBiphasicPulse(amp1, amp2, phase_dur1, 0)
    with pytest.raises(ValueError):
        AsymmetricBiphasicPulse(amp1, amp2, phase_dur1, phase_dur2,
                                interphase_dur=-1)
    with pytest.raises(ValueError):
        AsymmetricBiphasicPulse(amp1, amp2, phase_dur1, phase_dur2,
                                interphase_dur=interphase_dur, delay_dur=-1)
    with pytest.raises(ValueError):
        AsymmetricBiphasicPulse(amp1, amp2, phase_dur1, phase_dur2,
                                interphase_dur=interphase_dur,
                                delay_dur=delay_dur, stim_dur=1)
    with pytest.raises(ValueError):
        AsymmetricBiphasicPulse(amp1, amp2, phase_dur1, phase_dur2,
                                interphase_dur=interphase_dur,
                                delay_dur=delay_dur, electrode=['A1', 'B2'])
def test_BiphasicPulse(amp, interphase_dur, delay_dur, cathodic_first):
    """Exercise ``BiphasicPulse`` construction, timing and input validation.

    The pulse is probed at t=0 and in the middle of the leading phase, the
    interphase gap and the trailing phase. The leading phase is expected to
    be negative when ``cathodic_first`` is True and positive otherwise; the
    trailing phase has the opposite sign.
    """
    phase_dur = 3.19
    # Probe times: centre of phase 1, centre of the gap, centre of phase 2.
    t_phase1 = delay_dur + phase_dur / 2.0
    t_gap = delay_dur + phase_dur + interphase_dur / 2.0
    t_phase2 = delay_dur + interphase_dur + 1.5 * phase_dur
    lead = np.abs(amp)
    if cathodic_first:
        lead = -lead
    trail = -lead
    min_dur = 2 * phase_dur + delay_dur + interphase_dur
    # Keyword arguments shared by most constructor calls below:
    shape = dict(interphase_dur=interphase_dur, delay_dur=delay_dur,
                 cathodic_first=cathodic_first)
    # (time, expected amplitude) pairs checked on every non-zero pulse:
    probes = ((0, 0), (t_phase1, lead), (t_gap, 0), (t_phase2, trail))

    # Basic usage:
    pulse = BiphasicPulse(amp, phase_dur, **shape)
    for t, expected in probes:
        npt.assert_almost_equal(pulse[0, t], expected)
    npt.assert_almost_equal(pulse.time[0], 0)
    npt.assert_almost_equal(pulse.time[-1], min_dur, decimal=3)
    npt.assert_equal(pulse.cathodic_first, cathodic_first)
    npt.assert_equal(pulse.is_charge_balanced, True)

    # Custom stimulus duration and electrode name:
    pulse = BiphasicPulse(amp, phase_dur, stim_dur=100, electrode='B1',
                          **shape)
    for t, expected in probes:
        npt.assert_almost_equal(pulse[0, t], expected)
    npt.assert_almost_equal(pulse.time[0], 0)
    npt.assert_almost_equal(pulse.time[-1], 100)
    npt.assert_equal(pulse.electrodes, 'B1')

    # A stim_dur exactly equal to the pulse's own duration:
    exact_dur = 2 * phase_dur + interphase_dur + delay_dur
    pulse = BiphasicPulse(amp, phase_dur, stim_dur=exact_dur, **shape)
    npt.assert_almost_equal(pulse.time[0], 0)
    npt.assert_almost_equal(pulse.time[-1], exact_dur, decimal=6)

    # Zero amplitude:
    pulse = BiphasicPulse(0, phase_dur, **shape)
    npt.assert_almost_equal(pulse.data, 0)
    npt.assert_almost_equal(pulse.time[0], 0)
    npt.assert_almost_equal(pulse.time[-1], min_dur, decimal=3)
    npt.assert_equal(pulse.is_charge_balanced, True)

    # Wrapping a pulse in a Stimulus overwrites its attributes:
    wrapped = Stimulus(pulse, electrodes='AA1')
    npt.assert_equal(wrapped.electrodes, ['AA1'])

    # A list of pulses is concatenated into one row per pulse:
    stacked = Stimulus([pulse, pulse])
    npt.assert_equal(stacked.shape[0], 2)
    npt.assert_almost_equal(stacked.data[0, :], stacked.data[1, :])
    npt.assert_almost_equal(stacked.time, pulse.time)
    npt.assert_equal(stacked.electrodes, [0, 1])
    # ... and the concatenated electrodes can be renamed:
    renamed = Stimulus([pulse, pulse], electrodes=['C1', 'D2'])
    npt.assert_equal(renamed.electrodes, ['C1', 'D2'])

    # Floating point math with np.unique is tricky, but this works:
    BiphasicPulse(10, np.pi, interphase_dur=np.pi, delay_dur=np.pi,
                  stim_dur=5 * np.pi)

    # Invalid calls, each of which must raise ValueError:
    bad_calls = (
        lambda: BiphasicPulse(amp, 0),
        lambda: BiphasicPulse(amp, phase_dur, interphase_dur=-1),
        lambda: BiphasicPulse(amp, phase_dur, interphase_dur=interphase_dur,
                              delay_dur=-1),
        lambda: BiphasicPulse(amp, phase_dur, interphase_dur=interphase_dur,
                              delay_dur=delay_dur, stim_dur=1),
        lambda: BiphasicPulse(amp, phase_dur, interphase_dur=interphase_dur,
                              delay_dur=delay_dur, electrode=['A1', 'B2']),
    )
    for make_pulse in bad_calls:
        with pytest.raises(ValueError):
            make_pulse()
def test_Stimulus_plot():
    """Check ``Stimulus.plot`` for different ``time``, ``electrodes`` and
    ``ax`` arguments, including invalid ones.

    Every figure opened here is closed again before the test returns, so
    running the suite does not accumulate open matplotlib figures.
    """
    # Stimulus with one electrode
    stim = Stimulus([[0, -10, 10, -10, 10, -10, 0]],
                    time=[0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0])
    for time in [None, Ellipsis, slice(None)]:
        # Different ways to plot all data points:
        fig, ax = plt.subplots()
        stim.plot(time=time, ax=ax)
        npt.assert_equal(isinstance(ax, Subplot), True)
        npt.assert_almost_equal(ax.get_yticks(),
                                [stim.data.min(), 0, stim.data.max()])
        npt.assert_equal(len(ax.lines), 1)
        npt.assert_almost_equal(ax.lines[0].get_data()[1].min(),
                                stim.data.min())
        npt.assert_almost_equal(ax.lines[0].get_data()[1].max(),
                                stim.data.max())
        plt.close(fig)

    # Plot a range of time values (times are sliced, but end points are
    # interpolated):
    fig, ax = plt.subplots()
    ax = stim.plot(time=(0.2, 0.6), ax=ax)
    npt.assert_equal(isinstance(ax, Subplot), True)
    npt.assert_equal(len(ax.lines), 1)
    t_vals = ax.lines[0].get_data()[0]
    npt.assert_almost_equal(t_vals[0], 0.2)
    npt.assert_almost_equal(t_vals[-1], 0.6)
    plt.close(fig)

    # Plot exact time points:
    t_vals = [0.2, 0.3, 0.4]
    fig, ax = plt.subplots()
    stim.plot(time=t_vals, ax=ax)
    npt.assert_equal(isinstance(ax, Subplot), True)
    npt.assert_equal(len(ax.lines), 1)
    npt.assert_almost_equal(ax.lines[0].get_data()[0], t_vals)
    npt.assert_almost_equal(ax.lines[0].get_data()[1],
                            np.squeeze(stim[:, t_vals]))
    # This figure was previously never closed:
    plt.close(fig)

    # Plot multiple electrodes with string names:
    for n_electrodes in [2, 3, 4]:
        stim = Stimulus(np.random.rand(n_electrodes, 20),
                        electrodes=[f'E{i}' for i in range(n_electrodes)])
        fig, axes = plt.subplots(ncols=n_electrodes)
        stim.plot(ax=axes)
        npt.assert_equal(isinstance(axes, (list, np.ndarray)), True)
        for ax, electrode in zip(axes, stim.electrodes):
            npt.assert_equal(isinstance(ax, Subplot), True)
            npt.assert_equal(len(ax.lines), 1)
            npt.assert_equal(ax.get_ylabel(), electrode)
            npt.assert_almost_equal(ax.lines[0].get_data()[0], stim.time)
            npt.assert_almost_equal(ax.lines[0].get_data()[1],
                                    stim[electrode, :])
        plt.close(fig)

    # Invalid calls:
    with pytest.raises(TypeError):
        stim.plot(electrodes=1.2)
    with pytest.raises(TypeError):
        stim.plot(time=0)
    with pytest.raises(TypeError):
        stim.plot(ax='as')
    with pytest.raises(TypeError):
        stim.plot(time='0 0.1')
    with pytest.raises(NotImplementedError):
        Stimulus(np.ones(10)).plot()

    # 4 axes passed for a 3-electrode stimulus. The figure is created
    # outside ``pytest.raises`` so it can be closed afterwards:
    stim = Stimulus(np.ones((3, 10)))
    fig, axes = plt.subplots(nrows=4)
    with pytest.raises(ValueError):
        stim.plot(ax=axes)
    plt.close(fig)

    # One of the 3 entries in ``ax`` is not an Axes object:
    stim = Stimulus(np.ones((3, 10)))
    fig, axes = plt.subplots(nrows=3)
    axes[1] = 0
    with pytest.raises(TypeError):
        stim.plot(ax=axes)
    plt.close(fig)

    # The failing calls above that pass no ``ax`` may have opened a figure
    # before raising (depends on Stimulus.plot internals); make sure nothing
    # is left open.
    plt.close('all')
def test_Stimulus_append():
    """``Stimulus.append`` concatenates two stimuli in time and validates
    its argument."""
    head = Stimulus([[0, 1, 0]], time=[0, 1, 2])
    tail = Stimulus([[0, 2]], time=[0, 0.5])

    # Basic usage. The end point of ``head`` and the starting point of
    # ``tail`` are merged into a single sample:
    joined = head.append(tail)
    npt.assert_almost_equal(joined.data, [[0, 1, 0, 2]])
    npt.assert_almost_equal(joined.time, [0, 1, 2, 2.5])

    # When ``tail`` is shifted in time (``>> 10``), the appended part is
    # delayed accordingly:
    delayed = head.append(tail >> 10)
    npt.assert_almost_equal(delayed.time, [0, 1, 2, 12, 12.5])

    # (expected exception, offending call) pairs:
    failures = [
        # 'other' must be Stimulus:
        (TypeError, lambda: head.append(np.array([[0, 1, 2]]))),
        # other cannot have time=None:
        (ValueError, lambda: head.append(Stimulus(3))),
        # self cannot have time=None:
        (ValueError, lambda: Stimulus(3).append(head)),
        # Presumably rejected because electrode 'B1' does not match the
        # electrode of ``head``:
        (ValueError, lambda: head.append(Stimulus([[1, 2]], electrodes='B1'))),
        # negative time axis:
        (NotImplementedError,
         lambda: head.append(Stimulus([[0, 2]], time=[-1, 0]))),
    ]
    for exc_type, call in failures:
        with pytest.raises(exc_type):
            call()
def test_Stimulus___eq__():
    """Check ``Stimulus.__eq__`` / ``__ne__`` against a range of operands."""
    # Two Stimulus objects created from the same source data are considered
    # equal:
    for src in (3, [], np.ones(3), [3, 4, 5], np.ones((3, 6))):
        npt.assert_equal(Stimulus(src) == Stimulus(src), True)

    ref = Stimulus(np.ones((2, 3)), compress=True)

    # Compressed vs uncompressed:
    uncompressed = Stimulus(np.ones((2, 3)), compress=False)
    npt.assert_equal(ref == uncompressed, False)
    npt.assert_equal(ref != uncompressed, True)

    # Operands that differ from ``ref`` in exactly one respect:
    different = (
        Stimulus(ref, electrodes=[0, 'A2']),               # electrode names
        Stimulus(ref, time=[0, 3], compress=True),         # time points
        Stimulus(np.ones((2, 4))),                         # data shape
        Stimulus(np.ones(2)),                              # data shape
        Stimulus(np.ones((2, 3)) * 1.1, compress=True),    # data values
        Stimulus(np.ones((2, 5))),                         # shape
    )
    for other in different:
        npt.assert_equal(ref == other, False)

    # Different type:
    npt.assert_equal(ref == ODict(), False)
    npt.assert_equal(ref != ODict(), True)

    # Time vs no time:
    npt.assert_equal(Stimulus(2) == ref, False)

    # Annoying but possible:
    npt.assert_equal(Stimulus([]), Stimulus(()))
levels = np.linspace(-1, 1, num=5) data = levels[np.argmin(np.abs(x[:, np.newaxis] - levels), axis=1)] plt.plot(t, data, label='discretized') plt.plot(t, x, label='original') plt.xlabel('Time (ms)') plt.ylabel('Amplitude') plt.legend() ############################################################################## # We can turn this signal into a :py:class:`~pulse2percept.stimuli.Stimulus` # object as follows: from pulse2percept.stimuli import Stimulus stim = Stimulus(10 * data.reshape((1, -1)), time=t) stim.plot() ############################################################################## # Alternatively, we can automate this process by creating a new class # ``SinusoidalPulse`` that inherits from ``Stimulus``: class SinusoidalPulse(Stimulus): def __init__(self, amp, freq, phase, stim_dur, n_levels=5, dt=0.001): """Sinusoidal pulse Parameters ---------- amp : float Maximum stimulus amplitude (uA)
def test_Stimulus___getitem__():
    """Check ``Stimulus.__getitem__``: electrode indexing and time
    interpolation.

    The second index is a time value, not a column index: non-integer times
    are interpolated, and slices over time need an explicit step size.
    """
    grid = Stimulus(1 + np.arange(12).reshape((3, 4)))

    # Slicing. Each key below is exactly what ``grid[<expr>]`` would receive:
    exact = [
        (np.s_[:], grid.data),
        (np.s_[...], grid.data),
        (np.s_[:, :], grid.data),
        (np.s_[:2], grid.data[:2]),
        (np.s_[:, 0.0], grid.data[:, 0].reshape((-1, 1))),
        (np.s_[0, :], grid.data[0, :]),
        (np.s_[0, ...], grid.data[0, ...]),
        (np.s_[..., 0], grid.data[..., 0].reshape((-1, 1))),
    ]
    for key, want in exact:
        npt.assert_equal(grid[key], want)

    # More advanced slicing of time is possible, but needs a step size:
    for key in (np.s_[:, 2:5], np.s_[:, :3], np.s_[:, 2:]):
        with pytest.raises(ValueError):
            grid[key]
    stepped = [
        (np.s_[0, 1.2:1.65:0.15], [[2.2, 2.35, 2.5]]),
        (np.s_[0, :0.6:0.2], [[1.0, 1.2, 1.4]]),
        (np.s_[0, 2.7::0.2], [[3.7, 3.9]]),
        (np.s_[0, ::2.6], [[1.0, 3.6]]),
    ]
    for key, want in stepped:
        npt.assert_almost_equal(grid[key], want)

    # Single element:
    npt.assert_equal(grid[0, 0], grid.data[0, 0])

    # Interpolating time:
    npt.assert_almost_equal(grid[0, 2.6], 3.6)
    npt.assert_almost_equal(grid[..., 2.3],
                            np.array([[3.3], [7.3], [11.3]]), decimal=3)

    # The second dimension is not a column index!
    npt.assert_almost_equal(grid[0, 0], 1.0)
    npt.assert_almost_equal(grid[0, [0, 1]], np.array([[1.0, 2.0]]))
    npt.assert_almost_equal(grid[0, [0.21, 1]], np.array([[1.21, 2.0]]))
    npt.assert_almost_equal(grid[[0, 1], [0.21, 1]],
                            np.array([[1.21, 2.0], [5.21, 6.0]]))

    # "Valid" index errors:
    for key in (np.s_[10, :], np.s_[3.3, 0]):
        with pytest.raises(IndexError):
            grid[key]

    # Times can be extrapolated (take on value of end points):
    fresh = Stimulus(1 + np.arange(12).reshape((3, 4)))
    npt.assert_almost_equal(fresh[0, 9.901], 4)

    # If time=None, you cannot interpolate/extrapolate:
    no_time = Stimulus([3, 4, 5])
    npt.assert_almost_equal(no_time[0], no_time.data[0, 0])
    with pytest.raises(ValueError):
        no_time[0, 0.2]

    # With a single time point, interpolate is still possible:
    single = Stimulus(np.arange(3).reshape((-1, 1)))
    for key in (0, (0, 0), (0, 3.33)):
        npt.assert_almost_equal(single[key], single.data[0, 0])

    # Annoying but possible:
    empty = Stimulus([])
    npt.assert_almost_equal(empty[:], empty.data)
    with pytest.raises(IndexError):
        empty[0]

    # Electrodes by string:
    named = Stimulus([[0, 1], [2, 3]], electrodes=['A1', 'B2'])
    by_name = [
        ('A1', [0, 1]),
        (np.s_['A1', :], [0, 1]),
        (np.s_[['A1', 'B2'], 0], [[0], [2]]),
        (np.s_[['A1', 'B2'], :], named.data),
    ]
    for key, want in by_name:
        npt.assert_almost_equal(named[key], want)

    # Electrodes by slice:
    npt.assert_almost_equal(Stimulus(np.arange(10))[1::3],
                            np.array([[1], [4], [7]]))

    # Binary arrays:
    masked = Stimulus(np.arange(6).reshape((2, 3)), electrodes=['A1', 'B2'],
                      time=[0.1, 0.3, 0.5])
    elec, t = masked.electrodes, masked.time
    npt.assert_almost_equal(masked[elec != 'A1', :], [[3, 4, 5]])
    npt.assert_almost_equal(masked[elec == 'B2', :], [[3, 4, 5]])
    npt.assert_almost_equal(masked[elec == 'C9', :], np.zeros((0, 3)))
    npt.assert_almost_equal(masked[elec == 'C9', 0.1].size, 0)
    npt.assert_almost_equal(masked[elec == 'B2', 0.1001], 3.0005, decimal=3)
    npt.assert_almost_equal(masked[elec == 'B2', 0.2], 3.5)
    npt.assert_almost_equal(masked[:, t < 0.4], [[0, 1], [3, 4]])
    npt.assert_almost_equal(masked[elec == 'B2', t < 0.4], [3, 4])
    npt.assert_almost_equal(masked[:, t > 0.6], np.zeros((2, 0)))
    npt.assert_almost_equal(masked['A1', t > 0.6].size, 0)
    npt.assert_almost_equal(masked['A1', np.isclose(t, 0.3)], [1])
############################################################################### # Now we need to activate one electrode at a time, and predict the resulting # percept. We could build a :py:class:`~pulse2percept.stimuli.Stimulus` object # with a for loop that does just that, or we can use the following trick. # # The stimulus' data container is a (electrodes, timepoints) shaped 2D NumPy # array. Activating one electrode at a time is therefore the same as an # identity matrix whose size is equal to the number of electrodes. In code: # Find the names of all the electrodes in the dataset: electrodes = data.electrode.unique() # Activate one electrode at a time: import numpy as np from pulse2percept.stimuli import Stimulus argus.stim = Stimulus(np.eye(len(electrodes)), electrodes=electrodes) ############################################################################### # Using the model's # :py:func:`~pulse2percept.models.AxonMapModel.predict_percept`, we then get # a Percept object where each frame is the percept generated from activating # a single electrode: percepts = model.predict_percept(argus) percepts.play() ############################################################################### # Finally, we can visualize the ground-truth and simulated phosphenes # side-by-side: from pulse2percept.viz import plot_argus_simulated_phosphenes
############################################################################## # Simplest stimulus # --------------------- # :py:class:`~pulse2percept.stimuli.Stimulus` is the base class to generate # different types of stimuli. The simplest way to instantiate a Stimulus is # to pass a scalar value which is interpreted as the current amplitude # for a single electrode. # Let's start by importing necessary modules from pulse2percept.stimuli import (MonophasicPulse, BiphasicPulse, Stimulus, PulseTrain) import numpy as np stim = Stimulus(10) ############################################################################## # Parameters we don't specify will take on default values. We can inspect # all current model parameters as follows: print(stim) ############################################################################## # This command also reveals a number of other parameters to set, such as: # # * ``electrodes``: We can either specify the electrodes in the source # or within the stimulus. If none are specified it looks up the source # electrode. # # * ``metadata``: Optionally we can include metadata to the stimulus we