def test_predict_batched(engine):
    """Compare ``predict_percept_batched`` with repeated ``predict_percept``.

    With the 'jax' engine, every percept of the batched call must match the
    percept obtained by predicting each stimulus on its own. With any other
    engine, the batched call is expected to raise ``ImportError``.
    """
    if not has_jax:
        pytest.skip("Jax not installed")

    # A batch may freely mix plain dicts and Stimulus objects:
    stims = [
        {'A5': BiphasicPulseTrain(25, 4, 0.45),
         'C7': BiphasicPulseTrain(50, 2.5, 0.75)},
        {'B4': BiphasicPulseTrain(3, 1, 0.32)},
        Stimulus({'F3': BiphasicPulseTrain(12, 3, 1.2)}),
    ]
    implant = ArgusII()
    model = BiphasicAxonMapModel(engine=engine, xystep=2)
    model.build()

    if engine == 'jax':
        batched = model.predict_percept_batched(implant, stims)
        serial = []
        for stim in stims:
            # predict_percept takes the stimulus from the implant, so assign
            # each one before predicting:
            implant.stim = stim
            serial.append(model.predict_percept(implant))
        npt.assert_equal(len(serial), len(batched))
        for got, expected in zip(batched, serial):
            npt.assert_almost_equal(got.data, expected.data)
    else:
        # Engines other than 'jax' are expected to reject batched prediction:
        with pytest.raises(ImportError):
            model.predict_percept_batched(implant, stims)
def test_AxonMapModel(engine):
    """Exercise parameter handling and basic predictions of AxonMapModel.

    NOTE(review): another ``test_AxonMapModel`` definition appears later; if
    both live in the same module, pytest only collects the last one -- confirm
    whether this variant (``loc_od_x``/``loc_od_y``, ``axlambda=2``) is still
    needed.
    """
    set_params = {'xystep': 2, 'engine': engine, 'rho': 432, 'axlambda': 2,
                  'n_axons': 9, 'n_ax_segments': 50,
                  'xrange': (-30, 30), 'yrange': (-20, 20),
                  'loc_od_x': 5, 'loc_od_y': 6}
    model = AxonMapModel()
    for param in set_params:
        npt.assert_equal(hasattr(model.spatial, param), True)

    # User can override default values
    for key, value in set_params.items():
        setattr(model.spatial, key, value)
        npt.assert_equal(getattr(model.spatial, key), value)
    model = AxonMapModel(**set_params)
    model.build(**set_params)
    for key, value in set_params.items():
        npt.assert_equal(getattr(model.spatial, key), value)

    # Zeros in, zeros out (via the constructor and via attribute assignment):
    implant = ArgusII(stim=np.zeros(60))
    npt.assert_almost_equal(model.predict_percept(implant).data, 0)
    implant.stim = np.zeros(60)
    npt.assert_almost_equal(model.predict_percept(implant).data, 0)

    # Implant and model must be built for same eye. The left-eye implant is
    # constructed outside ``pytest.raises`` so that only ``predict_percept``
    # can satisfy the expected ValueError:
    implant = ArgusII(eye='LE', stim=np.zeros(60))
    with pytest.raises(ValueError):
        model.predict_percept(implant)
    with pytest.raises(ValueError):
        AxonMapModel(eye='invalid').build()
    with pytest.raises(ValueError):
        AxonMapModel(xystep=5).build(eye='invalid')
def test_predict_spatial_jax():
    """The jax and cython engines must produce matching percepts.

    The comparison is repeated after changing ``axlambda`` and ``rho`` and
    rebuilding both models, to make sure the jax engine clears its cache on
    build instead of reusing results from the previous parameters.
    """
    if not has_jax:
        pytest.skip("Jax not installed")
    jax_model = BiphasicAxonMapModel(engine='jax', xystep=2)
    cython_model = BiphasicAxonMapModel(engine='cython', xystep=2)
    jax_model.build()
    cython_model.build()

    implant = ArgusII()
    implant.stim = {
        'A5': BiphasicPulseTrain(25, 4, 0.45),
        'C7': BiphasicPulseTrain(50, 2.5, 0.75),
    }
    npt.assert_almost_equal(jax_model.predict_percept(implant).data,
                            cython_model.predict_percept(implant).data,
                            decimal=4)

    # Change model parameters on both engines, then rebuild:
    for mdl in (jax_model, cython_model):
        mdl.axlambda = 800
        mdl.rho = 50
    for mdl in (jax_model, cython_model):
        mdl.build()
    npt.assert_almost_equal(jax_model.predict_percept(implant).data,
                            cython_model.predict_percept(implant).data,
                            decimal=4)
def test_AxonMapModel(engine):
    """Exercise parameters, retinotopy, predictions and validation of AxonMapModel."""
    set_params = {
        'xystep': 2,
        'engine': engine,
        'rho': 432,
        'axlambda': 20,
        'n_axons': 9,
        'n_ax_segments': 50,
        'xrange': (-30, 30),
        'yrange': (-20, 20),
        'loc_od': (5, 6),
    }
    model = AxonMapModel()
    for param in set_params:
        npt.assert_equal(hasattr(model.spatial, param), True)

    # User can override default values
    for key, value in set_params.items():
        setattr(model.spatial, key, value)
        npt.assert_equal(getattr(model.spatial, key), value)
    model = AxonMapModel(**set_params)
    model.build(**set_params)
    for key, value in set_params.items():
        npt.assert_equal(getattr(model.spatial, key), value)

    # Converting ret <=> dva: the default map sends the origin to the origin
    # in both directions, and a custom retinotopy can be passed in:
    npt.assert_equal(isinstance(model.retinotopy, Watson2014Map), True)
    npt.assert_almost_equal(model.retinotopy.ret2dva(0, 0), (0, 0))
    npt.assert_almost_equal(model.retinotopy.dva2ret(0, 0), (0, 0))
    model2 = AxonMapModel(retinotopy=Watson2014DisplaceMap())
    npt.assert_equal(isinstance(model2.retinotopy, Watson2014DisplaceMap),
                     True)

    # Zeros in, zeros out (via the constructor and via attribute assignment):
    implant = ArgusII(stim=np.zeros(60))
    npt.assert_almost_equal(model.predict_percept(implant).data, 0)
    implant.stim = np.zeros(60)
    npt.assert_almost_equal(model.predict_percept(implant).data, 0)

    # Implant and model must be built for same eye. The left-eye implant is
    # constructed outside ``pytest.raises`` so that only ``predict_percept``
    # can satisfy the expected ValueError:
    implant = ArgusII(eye='LE', stim=np.zeros(60))
    with pytest.raises(ValueError):
        model.predict_percept(implant)
    with pytest.raises(ValueError):
        AxonMapModel(eye='invalid').build()
    with pytest.raises(ValueError):
        AxonMapModel(xystep=5).build(eye='invalid')

    # Lambda cannot be too small:
    with pytest.raises(ValueError):
        AxonMapModel(axlambda=9).build()
# will be centered over the fovea (at x=0, y=0) and aligned with the horizontal # meridian (rot=0): from pulse2percept.implants import ArgusII implant = ArgusII() ############################################################################## # The easiest way to assign a stimulus to the implant is to pass a NumPy array # that specifies the current amplitude to be applied to every electrode in the # implant. # # For example, the following sends 10 microamps to all 60 electrodes of the # implant: import numpy as np implant.stim = 10 * np.ones(60) ############################################################################## # .. note:: # # Some models can handle stimuli that have both a spatial and a temporal # component. the scoreboard model cannot. # # 3. Predicting the percept # ------------------------- # The third step is to apply the model to predict the percept resulting from # the specified stimulus. Note that this may take some time on your machine: percept = model.predict_percept(implant) ##############################################################################
model.plot()
implant.plot()

##############################################################################
# By default, the plots will be added to the current Axes object.
# Alternatively, you can pass ``ax=`` to specify in which Axes to plot.
#
# The easiest way to assign a stimulus to the implant is to pass a NumPy array
# that specifies the current amplitude to be applied to every electrode in the
# implant.
#
# For example, the following sends 1 microamp to all 60 electrodes of the
# implant:

import numpy as np
implant.stim = np.ones(60)

##############################################################################
# Predicting the percept
# ----------------------
# The third step is to apply the model to predict the percept resulting from
# the specified stimulus. Note that this may take some time on your machine:

percept = model.predict_percept(implant)

##############################################################################
# The resulting percept is stored in a
# :py:class:`~pulse2percept.percepts.Percept` object, which is similar in
# organization to the :py:class:`~pulse2percept.stimuli.Stimulus` object:
# the ``data`` container is a 3D NumPy array (Y, X, T) with labeled axes
# ``ydva``, ``xdva``, and ``time``.
def test_biphasicAxonMapSpatial(engine):
    """Test BiphasicAxonMapSpatial: validation, edge cases and reference output."""
    # Lambda cannot be too small:
    with pytest.raises(ValueError):
        BiphasicAxonMapSpatial(axlambda=9).build()

    model = BiphasicAxonMapModel(engine=engine, xystep=2).build()
    # Jax not implemented yet. Build the implant and its stimulus outside
    # ``pytest.raises`` so that only ``predict_percept`` can produce the
    # expected NotImplementedError:
    if engine == 'jax':
        implant = ArgusII()
        implant.stim = Stimulus({'A5': BiphasicPulseTrain(20, 1, 0.45)})
        with pytest.raises(NotImplementedError):
            model.predict_percept(implant)
        return

    # Only accepts biphasic pulse trains with no delay dur
    implant = ArgusI(stim=np.ones(16))
    with pytest.raises(TypeError):
        model.predict_percept(implant)

    # Nothing in, None out:
    npt.assert_equal(model.predict_percept(ArgusI()), None)

    # Zero in = zero out:
    implant = ArgusI(stim=np.zeros(16))
    percept = model.predict_percept(implant)
    npt.assert_equal(isinstance(percept, Percept), True)
    npt.assert_equal(percept.shape, list(model.grid.x.shape) + [1])
    npt.assert_almost_equal(percept.data, 0)
    npt.assert_equal(percept.time, None)

    # Should be equal to axon map model if effects models return 1
    model = BiphasicAxonMapSpatial(engine=engine, xystep=2)

    def bright_model(freq, amp, pdur):
        return 1

    def size_model(freq, amp, pdur):
        return 1

    def streak_model(freq, amp, pdur):
        return 1

    model.bright_model = bright_model
    model.size_model = size_model
    model.streak_model = streak_model
    model.build()
    axon_map = AxonMapSpatial(xystep=2).build()
    implant = ArgusII()
    implant.stim = Stimulus({'A5': BiphasicPulseTrain(20, 1, 0.45)})
    percept = model.predict_percept(implant)
    percept_axon = axon_map.predict_percept(implant)
    npt.assert_almost_equal(percept.data[:, :, 0],
                            percept_axon.get_brightest_frame())

    # Effect models must be callable
    model = BiphasicAxonMapSpatial(engine=engine, xystep=2)
    model.bright_model = 1.0
    with pytest.raises(TypeError):
        model.build()

    # If t_percept is not specified, there should only be one frame
    model = BiphasicAxonMapSpatial(engine=engine, xystep=2)
    model.build()
    implant = ArgusII()
    implant.stim = Stimulus({'A5': BiphasicPulseTrain(20, 1, 0.45)})
    percept = model.predict_percept(implant)
    npt.assert_equal(percept.time is None, True)

    # If t_percept is specified, only first frame should have data
    # and the rest should be empty
    percept = model.predict_percept(implant, t_percept=[0, 1, 2, 5, 10])
    npt.assert_equal(len(percept.time), 5)
    npt.assert_equal(np.any(percept.data[:, :, 0]), True)
    npt.assert_equal(np.any(percept.data[:, :, 1:]), False)

    # Test that default models give expected values (number of pixels above
    # each brightness level for a fixed single-electrode stimulus)
    model = BiphasicAxonMapSpatial(engine=engine, rho=400, axlambda=600,
                                   xystep=1, xrange=(-20, 20),
                                   yrange=(-15, 15))
    model.build()
    implant = ArgusII()
    implant.stim = Stimulus({'A4': BiphasicPulseTrain(20, 1, 1)})
    percept = model.predict_percept(implant)
    npt.assert_equal(np.sum(percept.data > 0.0813), 81)
    npt.assert_equal(np.sum(percept.data > 0.1626), 59)
    npt.assert_equal(np.sum(percept.data > 0.2439), 44)
    npt.assert_equal(np.sum(percept.data > 0.4065), 26)
    npt.assert_equal(np.sum(percept.data > 0.5691), 14)
def test_biphasicAxonMapModel(engine):
    """Test parameter propagation and validation in BiphasicAxonMapModel."""
    set_params = {
        'xystep': 2,
        'engine': engine,
        'rho': 432,
        'axlambda': 20,
        'n_axons': 9,
        'n_ax_segments': 50,
        'xrange': (-30, 30),
        'yrange': (-20, 20),
        'loc_od': (5, 6),
        'do_thresholding': False,
    }
    model = BiphasicAxonMapModel(engine=engine)
    for param in set_params:
        npt.assert_equal(hasattr(model.spatial, param), True)

    # We can set and get effects model params
    for atr in [f'a{i}' for i in range(10)]:
        npt.assert_equal(hasattr(model, atr), True)
    model.a0 = 5
    # Should propagate to size and bright model,
    # but should not be a member of streak or spatial
    npt.assert_equal(model.spatial.size_model.a0, 5)
    npt.assert_equal(model.spatial.bright_model.a0, 5)
    npt.assert_equal(hasattr(model.spatial.streak_model, 'a0'), False)
    with pytest.raises(AttributeError):
        model.spatial.__getattribute__('a0')

    # If the spatial model and an effects model have a parameter with the
    # same name, both need to be changed
    model.rho = 350
    model.axlambda = 450
    model.do_thresholding = True
    npt.assert_equal(model.spatial.size_model.rho, 350)
    npt.assert_equal(model.spatial.streak_model.axlambda, 450)
    npt.assert_equal(model.spatial.bright_model.do_thresholding, True)
    npt.assert_equal(model.rho, 350)
    npt.assert_equal(model.axlambda, 450)
    npt.assert_equal(model.do_thresholding, True)

    # Effect model parameters can be passed even in constructor
    model = BiphasicAxonMapModel(engine=engine, a0=5, rho=432)
    npt.assert_equal(model.a0, 5)
    npt.assert_equal(model.spatial.bright_model.a0, 5)
    npt.assert_equal(model.rho, 432)
    npt.assert_equal(model.spatial.size_model.rho, 432)

    # If parameter is not an effects model param, it can't be set
    with pytest.raises(FreezeError):
        model.invalid_param = 5

    # Custom parameters also propagate to effects models
    model = BiphasicAxonMapModel(engine=engine)

    class TestSizeModel():
        def __init__(self):
            self.test_param = 5

        def __call__(self, freq, amp, pdur):
            return 1

    model.size_model = TestSizeModel()
    model.test_param = 10
    npt.assert_equal(model.spatial.size_model.test_param, 10)
    with pytest.raises(AttributeError):
        model.spatial.__getattribute__('test_param')

    # Values are passed correctly even in another class's __init__.
    # This also tests for a recursion error in another class's __init__.
    class TestInitClassGood():
        def __init__(self):
            self.model = BiphasicAxonMapModel()
            # This shouldn't raise an error
            self.model.a0

    class TestInitClassBad():
        def __init__(self):
            self.model = BiphasicAxonMapModel()
            # This should raise FreezeError: 'a10' is not among the a0-a9
            # effects-model parameters checked above
            self.model.a10 = 999

    # If this fails, something is wrong with getattr / setattr logic
    TestInitClassGood()
    with pytest.raises(FreezeError):
        TestInitClassBad()

    # User can override default values
    model = BiphasicAxonMapModel(engine=engine)
    for key, value in set_params.items():
        setattr(model.spatial, key, value)
        npt.assert_equal(getattr(model.spatial, key), value)
    model = BiphasicAxonMapModel(**set_params)
    model.build(**set_params)
    for key, value in set_params.items():
        npt.assert_equal(getattr(model.spatial, key), value)

    # Zeros in, zeros out (via the constructor and via attribute assignment):
    implant = ArgusII(stim=np.zeros(60))
    npt.assert_almost_equal(model.predict_percept(implant).data, 0)
    implant.stim = np.zeros(60)
    npt.assert_almost_equal(model.predict_percept(implant).data, 0)

    # Implant and model must be built for same eye. The left-eye implant is
    # constructed outside ``pytest.raises`` so that only ``predict_percept``
    # can satisfy the expected ValueError:
    implant = ArgusII(eye='LE', stim=np.zeros(60))
    with pytest.raises(ValueError):
        model.predict_percept(implant)
    with pytest.raises(ValueError):
        BiphasicAxonMapModel(eye='invalid').build()
    with pytest.raises(ValueError):
        BiphasicAxonMapModel(xystep=5).build(eye='invalid')

    # Lambda cannot be too small:
    with pytest.raises(ValueError):
        BiphasicAxonMapModel(axlambda=9).build()
# :py:class:`~pulse2percept.implants.ArgusII` implant and use the # :py:class:`~pulse2percept.models.AxonMapModel` [Beyeler2019]_ to interpret it: # # .. important :: # # Don't forget to build the model before using ``predict_percept`` # from pulse2percept.implants import ArgusII from pulse2percept.models import AxonMapModel model = AxonMapModel() model.build() implant = ArgusII() implant.stim = GratingStimulus((25, 25), temporal_freq=0.1) percept = model.predict_percept(implant) percept.play() ##################################################################################### # As you can see in the above code segment, the stimulus passed to the implant does # not necessarily have to have the same dimensions as the electrode grid. # This is functionality built in to the implant code: The implant will automatically # rescale the stimulus to the appropriate size. # In the case of Argus II, the stimulus would thus be downscaled to a 6x10 image. # # Pre-Processing Stimuli # ---------------------- # # Since both :py:class:`~pulse2percept.stimuli.BarStimulus` and
############################################################################### # Now we need to activate one electrode at a time, and predict the resulting # percept. We could build a :py:class:`~pulse2percept.stimuli.Stimulus` object # with a for loop that does just that, or we can use the following trick. # # The stimulus' data container is a (electrodes, timepoints) shaped 2D NumPy # array. Activating one electrode at a time is therefore the same as an # identity matrix whose size is equal to the number of electrodes. In code: # Find the names of all the electrodes in the dataset: electrodes = data.electrode.unique() # Activate one electrode at a time: import numpy as np from pulse2percept.stimuli import Stimulus argus.stim = Stimulus(np.eye(len(electrodes)), electrodes=electrodes) ############################################################################### # Using the model's # :py:func:`~pulse2percept.models.AxonMapModel.predict_percept`, we then get # a Percept object where each frame is the percept generated from activating # a single electrode: percepts = model.predict_percept(argus) percepts.play() ############################################################################### # Finally, we can visualize the ground-truth and simulated phosphenes # side-by-side: from pulse2percept.viz import plot_argus_simulated_phosphenes
# will be centered over the fovea (at x=0, y=0) and aligned with the horizontal # meridian (rot=0): from pulse2percept.implants import ArgusII implant = ArgusII() ############################################################################## # The easiest way to assign a stimulus to the implant is to pass a NumPy array # that specifies the current amplitude to be applied to every electrode in the # implant. # # For example, the following sends 10 microamps to all 60 electrodes of the # implant: import numpy as np implant.stim = 10 * np.ones(60) ############################################################################## # .. note:: # # Some models can handle stimuli that have both a spatial and a temporal # component. the scoreboard model cannot. # # Predicting the percept # ---------------------- # The third step is to apply the model to predict the percept resulting from # the specified stimulus. Note that this may take some time on your machine: percept = model.predict_percept(implant) ##############################################################################
implant.plot()
plt.show()

##############################################################################
# As mentioned above, the Biphasic Axon Map Model only accepts
# :py:class:`~pulse2percept.stimuli.BiphasicPulseTrain`
# stimuli with no :py:attr:`~pulse2percept.stimuli.BiphasicPulseTrain.delay_dur`.
# The amplitude given to the BiphasicPulseTrain
# is interpreted as a multiple of the threshold amplitude (i.e., an amp of 1
# means 1xTh).
#
# You can easily assign BiphasicPulseTrains to electrodes with a dictionary.
# The following creates a train with 20Hz frequency, 1xTh amplitude, and 0.45ms
# pulse / phase duration.

implant.stim = {'A4': BiphasicPulseTrain(20, 1, 0.45)}
implant.stim.plot()

##############################################################################
# Finally, you can predict the percept resulting from stimulation:

percept = model.predict_percept(implant)
ax = percept.plot()
ax.set_title('Predicted percept')
plt.show()

##############################################################################
# Increasing the frequency will make phosphenes brighter:

fig, axes = plt.subplots(1, 2, sharex=True, sharey=True)
implant.stim = {'A4': BiphasicPulseTrain(50, 1, 0.45)}
new_percept = model.predict_percept(implant)
new_percept.plot(ax=axes[1])