def test_predict_batched(engine):
    """Batched prediction must agree with predicting one stimulus at a time.

    Parameters
    ----------
    engine : str
        Backend name forwarded to ``BiphasicAxonMapModel(engine=...)``.
    """
    if not has_jax:
        pytest.skip("Jax not installed")

    # The batch deliberately mixes plain dicts with a Stimulus object
    stims = [
        {
            'A5': BiphasicPulseTrain(25, 4, 0.45),
            'C7': BiphasicPulseTrain(50, 2.5, 0.75),
        },
        {'B4': BiphasicPulseTrain(3, 1, 0.32)},
        Stimulus({'F3': BiphasicPulseTrain(12, 3, 1.2)}),
    ]
    implant = ArgusII()
    model = BiphasicAxonMapModel(engine=engine, xystep=2)
    model.build()

    # Any engine other than 'jax' is expected to refuse batched prediction
    # with an ImportError; nothing else is checked for those engines.
    if engine != 'jax':
        with pytest.raises(ImportError):
            model.predict_percept_batched(implant, stims)
        return

    percepts_batched = model.predict_percept_batched(implant, stims)

    def _predict_single(stim):
        # Assign the stimulus to the shared implant, then predict serially
        implant.stim = stim
        return model.predict_percept(implant)

    percepts_serial = [_predict_single(stim) for stim in stims]

    npt.assert_equal(len(percepts_serial), len(percepts_batched))
    for batched, serial in zip(percepts_batched, percepts_serial):
        npt.assert_almost_equal(batched.data, serial.data)
Beispiel #2
0
def test_AxonMapModel(engine):
    """Check AxonMapModel parameters, zero-stimulus output and eye validation.

    Parameters
    ----------
    engine : str
        Backend name stored in the model's ``engine`` parameter.
    """
    set_params = {'xystep': 2, 'engine': engine, 'rho': 432, 'axlambda': 2,
                  'n_axons': 9, 'n_ax_segments': 50,
                  'xrange': (-30, 30), 'yrange': (-20, 20),
                  'loc_od_x': 5, 'loc_od_y': 6}
    model = AxonMapModel()
    for param in set_params:
        npt.assert_equal(hasattr(model.spatial, param), True)

    # User can override default values
    for key, value in set_params.items():
        setattr(model.spatial, key, value)
        npt.assert_equal(getattr(model.spatial, key), value)
    model = AxonMapModel(**set_params)
    model.build(**set_params)
    for key, value in set_params.items():
        npt.assert_equal(getattr(model.spatial, key), value)

    # Zeros in, zeros out (stimulus given at construction and via setter):
    implant = ArgusII(stim=np.zeros(60))
    npt.assert_almost_equal(model.predict_percept(implant).data, 0)
    implant.stim = np.zeros(60)
    npt.assert_almost_equal(model.predict_percept(implant).data, 0)

    # Implant and model must be built for same eye.
    # The implant is created outside the ``raises`` block so that only the
    # prediction call may satisfy the expected ValueError; a ValueError from
    # the constructor would otherwise let the test pass for the wrong reason.
    implant = ArgusII(eye='LE', stim=np.zeros(60))
    with pytest.raises(ValueError):
        model.predict_percept(implant)
    with pytest.raises(ValueError):
        AxonMapModel(eye='invalid').build()
    with pytest.raises(ValueError):
        AxonMapModel(xystep=5).build(eye='invalid')
def test_predict_spatial_jax():
    """The 'jax' and 'cython' engines must predict the same percept.

    The comparison is done twice: once with the initial parameters, and once
    after changing ``axlambda`` and ``rho`` and rebuilding, to make sure the
    jax engine clears its cache on build.
    """
    if not has_jax:
        pytest.skip("Jax not installed")

    jax_model = BiphasicAxonMapModel(engine='jax', xystep=2)
    cython_model = BiphasicAxonMapModel(engine='cython', xystep=2)
    jax_model.build()
    cython_model.build()

    implant = ArgusII()
    implant.stim = {
        'A5': BiphasicPulseTrain(25, 4, 0.45),
        'C7': BiphasicPulseTrain(50, 2.5, 0.75),
    }
    npt.assert_almost_equal(jax_model.predict_percept(implant).data,
                            cython_model.predict_percept(implant).data,
                            decimal=4)

    # Change model parameters on both models, then rebuild and compare again
    for name, value in (('axlambda', 800), ('rho', 50)):
        setattr(jax_model, name, value)
        setattr(cython_model, name, value)
    jax_model.build()
    cython_model.build()
    npt.assert_almost_equal(jax_model.predict_percept(implant).data,
                            cython_model.predict_percept(implant).data,
                            decimal=4)
Beispiel #4
0
def test_AxonMapModel(engine):
    """Check AxonMapModel parameters, retinotopy, and input validation.

    Covers: presence and overriding of spatial parameters, the default and a
    user-supplied retinotopy, zero-stimulus output, eye mismatch / invalid
    eye errors, and rejection of a too-small ``axlambda``.

    Parameters
    ----------
    engine : str
        Backend name stored in the model's ``engine`` parameter.
    """
    set_params = {
        'xystep': 2,
        'engine': engine,
        'rho': 432,
        'axlambda': 20,
        'n_axons': 9,
        'n_ax_segments': 50,
        'xrange': (-30, 30),
        'yrange': (-20, 20),
        'loc_od': (5, 6)
    }
    model = AxonMapModel()
    for param in set_params:
        npt.assert_equal(hasattr(model.spatial, param), True)

    # User can override default values
    for key, value in set_params.items():
        setattr(model.spatial, key, value)
        npt.assert_equal(getattr(model.spatial, key), value)
    model = AxonMapModel(**set_params)
    model.build(**set_params)
    for key, value in set_params.items():
        npt.assert_equal(getattr(model.spatial, key), value)

    # Converting ret <=> dva
    npt.assert_equal(isinstance(model.retinotopy, Watson2014Map), True)
    npt.assert_almost_equal(model.retinotopy.ret2dva(0, 0), (0, 0))
    npt.assert_almost_equal(model.retinotopy.dva2ret(0, 0), (0, 0))
    model2 = AxonMapModel(retinotopy=Watson2014DisplaceMap())
    npt.assert_equal(isinstance(model2.retinotopy, Watson2014DisplaceMap),
                     True)

    # Zeros in, zeros out (stimulus given at construction and via setter):
    implant = ArgusII(stim=np.zeros(60))
    npt.assert_almost_equal(model.predict_percept(implant).data, 0)
    implant.stim = np.zeros(60)
    npt.assert_almost_equal(model.predict_percept(implant).data, 0)

    # Implant and model must be built for same eye.
    # The implant is created outside the ``raises`` block so that only the
    # prediction call may satisfy the expected ValueError; a ValueError from
    # the constructor would otherwise let the test pass for the wrong reason.
    implant = ArgusII(eye='LE', stim=np.zeros(60))
    with pytest.raises(ValueError):
        model.predict_percept(implant)
    with pytest.raises(ValueError):
        AxonMapModel(eye='invalid').build()
    with pytest.raises(ValueError):
        AxonMapModel(xystep=5).build(eye='invalid')

    # Lambda cannot be too small:
    with pytest.raises(ValueError):
        AxonMapModel(axlambda=9).build()
# will be centered over the fovea (at x=0, y=0) and aligned with the horizontal
# meridian (rot=0):

from pulse2percept.implants import ArgusII
implant = ArgusII()

##############################################################################
# The easiest way to assign a stimulus to the implant is to pass a NumPy array
# that specifies the current amplitude to be applied to every electrode in the
# implant.
#
# For example, the following sends 10 microamps to all 60 electrodes of the
# implant:

import numpy as np
implant.stim = 10 * np.ones(60)

##############################################################################
# .. note::
#
#     Some models can handle stimuli that have both a spatial and a temporal
#     component. The scoreboard model cannot.
#
# 3. Predicting the percept
# -------------------------
# The third step is to apply the model to predict the percept resulting from
# the specified stimulus. Note that this may take some time on your machine:

percept = model.predict_percept(implant)

##############################################################################
model.plot()
implant.plot()

##############################################################################
# By default, the plots will be added to the current Axes object.
# Alternatively, you can pass ``ax=`` to specify in which Axes to plot.
#
# The easiest way to assign a stimulus to the implant is to pass a NumPy array
# that specifies the current amplitude to be applied to every electrode in the
# implant.
#
# For example, the following sends 1 microamp to all 60 electrodes of the
# implant:

import numpy as np
implant.stim = np.ones(60)

##############################################################################
# Predicting the percept
# ----------------------
# The third step is to apply the model to predict the percept resulting from
# the specified stimulus. Note that this may take some time on your machine:

percept = model.predict_percept(implant)

##############################################################################
# The resulting percept is stored in a
# :py:class:`~pulse2percept.percepts.Percept` object, which is similar in
# organization to the :py:class:`~pulse2percept.stimuli.Stimulus` object:
# the ``data`` container is a 3D NumPy array (Y, X, T) with labeled axes
# ``xdva``, ``ydva``, and ``time``.
def test_biphasicAxonMapSpatial(engine):
    """Exercise the biphasic axon map spatial model for one engine.

    Covers: rejection of a too-small ``axlambda``, the not-implemented jax
    path, stimulus type checking, empty / zero stimuli, agreement with
    ``AxonMapSpatial`` when all effects models return 1, the requirement that
    effects models are callable, ``t_percept`` handling, and a regression
    check on the default effects models.

    Parameters
    ----------
    engine : str
        Backend name forwarded to the models as ``engine``.
    """
    # Lambda cannot be too small:
    with pytest.raises(ValueError):
        BiphasicAxonMapSpatial(axlambda=9).build()

    model = BiphasicAxonMapModel(engine=engine, xystep=2).build()
    # Jax not implemented yet.
    # The implant is prepared outside the ``raises`` block so that only the
    # prediction call itself may satisfy the expected NotImplementedError.
    if engine == 'jax':
        implant = ArgusII()
        implant.stim = Stimulus({'A5': BiphasicPulseTrain(20, 1, 0.45)})
        with pytest.raises(NotImplementedError):
            model.predict_percept(implant)
        return

    # Only accepts biphasic pulse trains with no delay dur
    implant = ArgusI(stim=np.ones(16))
    with pytest.raises(TypeError):
        model.predict_percept(implant)

    # Nothing in, None out:
    npt.assert_equal(model.predict_percept(ArgusI()), None)

    # Zero in = zero out:
    implant = ArgusI(stim=np.zeros(16))
    percept = model.predict_percept(implant)
    npt.assert_equal(isinstance(percept, Percept), True)
    npt.assert_equal(percept.shape, list(model.grid.x.shape) + [1])
    npt.assert_almost_equal(percept.data, 0)
    npt.assert_equal(percept.time, None)

    # Should be equal to axon map model if effects models return 1
    model = BiphasicAxonMapSpatial(engine=engine, xystep=2)

    def bright_model(freq, amp, pdur):
        return 1

    def size_model(freq, amp, pdur):
        return 1

    def streak_model(freq, amp, pdur):
        return 1

    model.bright_model = bright_model
    model.size_model = size_model
    model.streak_model = streak_model
    model.build()
    axon_map = AxonMapSpatial(xystep=2).build()
    implant = ArgusII()
    implant.stim = Stimulus({'A5': BiphasicPulseTrain(20, 1, 0.45)})
    percept = model.predict_percept(implant)
    percept_axon = axon_map.predict_percept(implant)
    npt.assert_almost_equal(percept.data[:, :, 0],
                            percept_axon.get_brightest_frame())

    # Effect models must be callable
    model = BiphasicAxonMapSpatial(engine=engine, xystep=2)
    model.bright_model = 1.0
    with pytest.raises(TypeError):
        model.build()

    # If t_percept is not specified, there should only be one frame
    model = BiphasicAxonMapSpatial(engine=engine, xystep=2)
    model.build()
    implant = ArgusII()
    implant.stim = Stimulus({'A5': BiphasicPulseTrain(20, 1, 0.45)})
    percept = model.predict_percept(implant)
    npt.assert_equal(percept.time is None, True)
    # If t_percept is specified, only first frame should have data
    # and the rest should be empty
    percept = model.predict_percept(implant, t_percept=[0, 1, 2, 5, 10])
    npt.assert_equal(len(percept.time), 5)
    npt.assert_equal(np.any(percept.data[:, :, 0]), True)
    npt.assert_equal(np.any(percept.data[:, :, 1:]), False)

    # Test that default models give expected values.
    # Each assertion counts the pixels brighter than a fixed threshold.
    model = BiphasicAxonMapSpatial(engine=engine,
                                   rho=400,
                                   axlambda=600,
                                   xystep=1,
                                   xrange=(-20, 20),
                                   yrange=(-15, 15))
    model.build()
    implant = ArgusII()
    implant.stim = Stimulus({'A4': BiphasicPulseTrain(20, 1, 1)})
    percept = model.predict_percept(implant)
    npt.assert_equal(np.sum(percept.data > 0.0813), 81)
    npt.assert_equal(np.sum(percept.data > 0.1626), 59)
    npt.assert_equal(np.sum(percept.data > 0.2439), 44)
    npt.assert_equal(np.sum(percept.data > 0.4065), 26)
    npt.assert_equal(np.sum(percept.data > 0.5691), 14)
def test_biphasicAxonMapModel(engine):
    """Exercise parameter handling and validation of BiphasicAxonMapModel.

    Covers: presence of spatial parameters, propagation of effects-model
    parameters (set after construction, in the constructor, and on custom
    effects models), rejection of unknown parameters, attribute access from
    within another class's ``__init__``, overriding defaults, zero-stimulus
    output, eye validation, and rejection of a too-small ``axlambda``.

    Parameters
    ----------
    engine : str
        Backend name forwarded to the model as ``engine``.
    """
    set_params = {
        'xystep': 2,
        'engine': engine,
        'rho': 432,
        'axlambda': 20,
        'n_axons': 9,
        'n_ax_segments': 50,
        'xrange': (-30, 30),
        'yrange': (-20, 20),
        'loc_od': (5, 6),
        'do_thresholding': False
    }
    model = BiphasicAxonMapModel(engine=engine)
    for param in set_params:
        npt.assert_equal(hasattr(model.spatial, param), True)

    # We can set and get effects model params
    for atr in ['a' + str(i) for i in range(0, 10)]:
        npt.assert_equal(hasattr(model, atr), True)
    model.a0 = 5
    # Should propagate to size and bright model,
    # but should not be a member of streak or spatial
    npt.assert_equal(model.spatial.size_model.a0, 5)
    npt.assert_equal(model.spatial.bright_model.a0, 5)
    npt.assert_equal(hasattr(model.spatial.streak_model, 'a0'), False)
    # NOTE(review): __getattribute__ is called directly, presumably so that
    # any __getattr__ fallback on the spatial model is bypassed -- confirm.
    with pytest.raises(AttributeError):
        model.spatial.__getattribute__('a0')
    # If the spatial model and an effects model have a parameter with the
    # same name, both need to be changed
    model.rho = 350
    model.axlambda = 450
    model.do_thresholding = True
    npt.assert_equal(model.spatial.size_model.rho, 350)
    npt.assert_equal(model.spatial.streak_model.axlambda, 450)
    npt.assert_equal(model.spatial.bright_model.do_thresholding, True)
    npt.assert_equal(model.rho, 350)
    npt.assert_equal(model.axlambda, 450)
    npt.assert_equal(model.do_thresholding, True)

    # Effect model parameters can be passed even in constructor
    model = BiphasicAxonMapModel(engine=engine, a0=5, rho=432)
    npt.assert_equal(model.a0, 5)
    npt.assert_equal(model.spatial.bright_model.a0, 5)
    npt.assert_equal(model.rho, 432)
    npt.assert_equal(model.spatial.size_model.rho, 432)

    # If parameter is not an effects model param, it can't be set
    with pytest.raises(FreezeError):
        model.invalid_param = 5

    # Custom parameters also propagate to effects models
    model = BiphasicAxonMapModel(engine=engine)

    class TestSizeModel():
        # Minimal callable effects model carrying one custom parameter
        def __init__(self):
            self.test_param = 5

        def __call__(self, freq, amp, pdur):
            return 1

    model.size_model = TestSizeModel()
    model.test_param = 10
    npt.assert_equal(model.spatial.size_model.test_param, 10)
    with pytest.raises(AttributeError):
        model.spatial.__getattribute__('test_param')

    # Values are passed correctly even in another class's __init__
    # This also tests for recursion error in another class's __init__
    class TestInitClassGood():
        def __init__(self):
            self.model = BiphasicAxonMapModel()
            # This shouldn't raise an error
            self.model.a0

    class TestInitClassBad():
        def __init__(self):
            self.model = BiphasicAxonMapModel()
            # This should
            self.model.a10 = 999

    # If this fails, something is wrong with getattr / setattr logic
    TestInitClassGood()
    with pytest.raises(FreezeError):
        TestInitClassBad()

    # User can override default values
    model = BiphasicAxonMapModel(engine=engine)
    for key, value in set_params.items():
        setattr(model.spatial, key, value)
        npt.assert_equal(getattr(model.spatial, key), value)
    model = BiphasicAxonMapModel(**set_params)
    model.build(**set_params)
    for key, value in set_params.items():
        npt.assert_equal(getattr(model.spatial, key), value)

    # Zeros in, zeros out (stimulus given at construction and via setter):
    implant = ArgusII(stim=np.zeros(60))
    npt.assert_almost_equal(model.predict_percept(implant).data, 0)
    implant.stim = np.zeros(60)
    npt.assert_almost_equal(model.predict_percept(implant).data, 0)

    # Implant and model must be built for same eye.
    # The implant is created outside the ``raises`` block so that only the
    # prediction call may satisfy the expected ValueError; a ValueError from
    # the constructor would otherwise let the test pass for the wrong reason.
    implant = ArgusII(eye='LE', stim=np.zeros(60))
    with pytest.raises(ValueError):
        model.predict_percept(implant)
    with pytest.raises(ValueError):
        BiphasicAxonMapModel(eye='invalid').build()
    with pytest.raises(ValueError):
        BiphasicAxonMapModel(xystep=5).build(eye='invalid')

    # Lambda cannot be too small:
    with pytest.raises(ValueError):
        BiphasicAxonMapModel(axlambda=9).build()
# :py:class:`~pulse2percept.implants.ArgusII` implant and use the
# :py:class:`~pulse2percept.models.AxonMapModel` [Beyeler2019]_ to interpret it:
#
# .. important ::
#
#   Don't forget to build the model before using ``predict_percept``
#

from pulse2percept.implants import ArgusII
from pulse2percept.models import AxonMapModel

model = AxonMapModel()
model.build()

implant = ArgusII()
implant.stim = GratingStimulus((25, 25), temporal_freq=0.1)

percept = model.predict_percept(implant)
percept.play()

#####################################################################################
# As you can see in the above code segment, the stimulus passed to the implant does
# not necessarily have to have the same dimensions as the electrode grid.
# This functionality is built into the implant code: the implant will automatically
# rescale the stimulus to the appropriate size.
# In the case of Argus II, the stimulus would thus be downscaled to a 6x10 image.
#
# Pre-Processing Stimuli
# ----------------------
#
# Since both :py:class:`~pulse2percept.stimuli.BarStimulus` and
Beispiel #10
0
###############################################################################
# Now we need to activate one electrode at a time, and predict the resulting
# percept. We could build a :py:class:`~pulse2percept.stimuli.Stimulus` object
# with a for loop that does just that, or we can use the following trick.
#
# The stimulus' data container is an (electrodes, timepoints)-shaped 2D NumPy
# array. Activating one electrode at a time is therefore the same as an
# identity matrix whose size is equal to the number of electrodes. In code:

# Find the names of all the electrodes in the dataset:
electrodes = data.electrode.unique()
# Activate one electrode at a time:
import numpy as np
from pulse2percept.stimuli import Stimulus
argus.stim = Stimulus(np.eye(len(electrodes)), electrodes=electrodes)

###############################################################################
# Using the model's
# :py:func:`~pulse2percept.models.AxonMapModel.predict_percept`, we then get
# a Percept object where each frame is the percept generated from activating
# a single electrode:

percepts = model.predict_percept(argus)
percepts.play()

###############################################################################
# Finally, we can visualize the ground-truth and simulated phosphenes
# side-by-side:

from pulse2percept.viz import plot_argus_simulated_phosphenes
# will be centered over the fovea (at x=0, y=0) and aligned with the horizontal
# meridian (rot=0):

from pulse2percept.implants import ArgusII
implant = ArgusII()

##############################################################################
# The easiest way to assign a stimulus to the implant is to pass a NumPy array
# that specifies the current amplitude to be applied to every electrode in the
# implant.
#
# For example, the following sends 10 microamps to all 60 electrodes of the
# implant:

import numpy as np
implant.stim = 10 * np.ones(60)

##############################################################################
# .. note::
#
#     Some models can handle stimuli that have both a spatial and a temporal
#     component. The scoreboard model cannot.
#
# Predicting the percept
# ----------------------
# The third step is to apply the model to predict the percept resulting from
# the specified stimulus. Note that this may take some time on your machine:

percept = model.predict_percept(implant)

##############################################################################
Beispiel #12
0
implant.plot()
plt.show()

##############################################################################
# As mentioned above, the Biphasic Axon Map Model only accepts
# :py:class:`~pulse2percept.stimuli.BiphasicPulseTrain`
# stimuli with no :py:attr:`~pulse2percept.stimuli.BiphasicPulseTrain.delay_dur`.
# The amplitude given to the BiphasicPulseTrain
# is interpreted as a multiple of the threshold (i.e., an amp of 1 means
# 1xTh).
#
# You can easily assign BiphasicPulseTrains to electrodes with a dictionary.
# The following creates a train with 20Hz frequency, 1xTh amplitude, and 0.45ms
# pulse / phase duration.

implant.stim = {'A4': BiphasicPulseTrain(20, 1, 0.45)}
implant.stim.plot()

##############################################################################
# Finally, you can predict the percept resulting from stimulation.

percept = model.predict_percept(implant)
ax = percept.plot()
ax.set_title('Predicted percept')
plt.show()
##############################################################################
# Increasing the frequency will make phosphenes brighter.
fig, axes = plt.subplots(1, 2, sharex=True, sharey=True)
implant.stim = {'A4': BiphasicPulseTrain(50, 1, 0.45)}
new_percept = model.predict_percept(implant)
new_percept.plot(ax=axes[1])