def test_ScoreboardModel_predict_percept():
    """Check ScoreboardModel percepts for single- and all-electrode stimuli.

    Also verifies that the full model reproduces the output of its spatial
    component, ``ScoreboardSpatial``, when both use the same parameters.
    """
    model = ScoreboardModel(xystep=0.55, rho=100, thresh_percept=0)
    model.build()
    # Single-electrode stim:
    img_stim = np.zeros(60)
    img_stim[47] = 1
    percept = model.predict_percept(ArgusII(stim=img_stim))
    # Single bright pixel, very small Gaussian kernel:
    npt.assert_equal(np.sum(percept.data > 0.9), 1)
    npt.assert_equal(np.sum(percept.data > 0.5), 1)
    npt.assert_equal(np.sum(percept.data > 0.1), 7)
    npt.assert_equal(np.sum(percept.data > 0.00001), 35)
    # Brightest pixel is in lower right:
    npt.assert_almost_equal(percept.data[34, 47, 0], np.max(percept.data))

    # Full Argus II: 60 bright spots
    model = ScoreboardModel(engine='serial', xystep=0.55, rho=100)
    model.build()
    percept = model.predict_percept(ArgusII(stim=np.ones(60)))
    npt.assert_equal(np.sum(np.isclose(percept.data, 0.9, rtol=0.1, atol=0.1)),
                     60)

    # Model gives same outcome as Spatial. The spatial model must use the
    # model's grid (same ``xystep``) and must make the prediction itself;
    # otherwise the model would only be compared against itself.
    spatial = ScoreboardSpatial(engine='serial', xystep=0.55, rho=100)
    spatial.build()
    spatial_percept = spatial.predict_percept(ArgusII(stim=np.ones(60)))
    npt.assert_almost_equal(percept.data, spatial_percept.data)
def test_predict_spatial_jax():
    """The 'jax' engine must reproduce the 'cython' engine's percepts.

    The comparison is repeated after changing ``axlambda`` and ``rho`` and
    rebuilding, to make sure the jax engine does not reuse stale results.
    """
    if not has_jax:
        pytest.skip("Jax not installed")
    models = [BiphasicAxonMapModel(engine=eng, xystep=2)
              for eng in ('jax', 'cython')]
    for m in models:
        m.build()
    implant = ArgusII()
    implant.stim = {
        'A5': BiphasicPulseTrain(25, 4, 0.45),
        'C7': BiphasicPulseTrain(50, 2.5, 0.75)
    }

    def _assert_engines_agree():
        """Predict with both engines and compare to 4 decimals."""
        p_jax, p_cython = [m.predict_percept(implant) for m in models]
        npt.assert_almost_equal(p_jax.data, p_cython.data, decimal=4)

    _assert_engines_agree()

    # Change parameters on both models, then rebuild both; jax must clear
    # its cache on build:
    for attr, new_value in (('axlambda', 800), ('rho', 50)):
        for m in models:
            setattr(m, attr, new_value)
    for m in models:
        m.build()
    _assert_engines_agree()
def test_predict_batched(engine):
    """Batched prediction matches one-at-a-time prediction.

    For engines other than 'jax', ``predict_percept_batched`` is expected to
    raise ImportError.
    """
    if not has_jax:
        pytest.skip("Jax not installed")

    # A mix of valid stimulus types (plain dicts and a Stimulus) is allowed:
    two_electrodes = {
        'A5': BiphasicPulseTrain(25, 4, 0.45),
        'C7': BiphasicPulseTrain(50, 2.5, 0.75)
    }
    one_electrode = {'B4': BiphasicPulseTrain(3, 1, 0.32)}
    wrapped = Stimulus({'F3': BiphasicPulseTrain(12, 3, 1.2)})
    stims = [two_electrodes, one_electrode, wrapped]

    implant = ArgusII()
    model = BiphasicAxonMapModel(engine=engine, xystep=2)
    model.build()
    if engine != 'jax':
        # Batched prediction requires the jax engine:
        with pytest.raises(ImportError):
            model.predict_percept_batched(implant, stims)
        return

    batched = model.predict_percept_batched(implant, stims)

    def _predict_single(stim):
        """Assign `stim` to the shared implant and predict its percept."""
        implant.stim = stim
        return model.predict_percept(implant)

    serial = [_predict_single(stim) for stim in stims]

    npt.assert_equal(len(serial), len(batched))
    for p_batched, p_serial in zip(batched, serial):
        npt.assert_almost_equal(p_batched.data, p_serial.data)
# Example #4
def test_AxonMapModel(engine):
    """AxonMapModel exposes its spatial parameters, accepts overrides,
    returns all-zero percepts for all-zero stimuli, and validates ``eye``.
    """
    set_params = dict(xystep=2, engine=engine, rho=432, axlambda=2,
                      n_axons=9, n_ax_segments=50,
                      xrange=(-30, 30), yrange=(-20, 20),
                      loc_od_x=5, loc_od_y=6)
    default_model = AxonMapModel()
    # Every parameter must exist on the spatial model:
    for name in set_params:
        npt.assert_equal(hasattr(default_model.spatial, name), True)

    # User can override default values one by one...
    for name, new_value in set_params.items():
        setattr(default_model.spatial, name, new_value)
        npt.assert_equal(getattr(default_model.spatial, name), new_value)
    # ...or through the constructor and build():
    model = AxonMapModel(**set_params)
    model.build(**set_params)
    for name, new_value in set_params.items():
        npt.assert_equal(getattr(model.spatial, name), new_value)

    def _assert_dark(imp):
        """Assert that `model` predicts an all-zero percept for `imp`."""
        npt.assert_almost_equal(model.predict_percept(imp).data, 0)

    # Zeros in, zeros out (stim given to the constructor, then reassigned):
    implant = ArgusII(stim=np.zeros(60))
    _assert_dark(implant)
    implant.stim = np.zeros(60)
    _assert_dark(implant)

    # Implant and model must be built for same eye:
    with pytest.raises(ValueError):
        model.predict_percept(ArgusII(eye='LE', stim=np.zeros(60)))
    # An invalid eye is rejected whether given to the constructor or build():
    with pytest.raises(ValueError):
        AxonMapModel(eye='invalid').build()
    with pytest.raises(ValueError):
        AxonMapModel(xystep=5).build(eye='invalid')
# Example #5
def test_AxonMapModel(engine):
    """AxonMapModel exposes and accepts its spatial parameters, uses the
    expected retinotopy, handles zero stimuli, and validates its inputs.
    """
    set_params = dict(xystep=2, engine=engine, rho=432, axlambda=20,
                      n_axons=9, n_ax_segments=50,
                      xrange=(-30, 30), yrange=(-20, 20),
                      loc_od=(5, 6))
    default_model = AxonMapModel()
    # Every parameter must exist on the spatial model:
    for name in set_params:
        npt.assert_equal(hasattr(default_model.spatial, name), True)

    # User can override default values one by one...
    for name, new_value in set_params.items():
        setattr(default_model.spatial, name, new_value)
        npt.assert_equal(getattr(default_model.spatial, name), new_value)
    # ...or through the constructor and build():
    model = AxonMapModel(**set_params)
    model.build(**set_params)
    for name, new_value in set_params.items():
        npt.assert_equal(getattr(model.spatial, name), new_value)

    # Converting ret <=> dva maps the origin onto itself in both directions:
    npt.assert_equal(isinstance(model.retinotopy, Watson2014Map), True)
    for convert in (model.retinotopy.ret2dva, model.retinotopy.dva2ret):
        npt.assert_almost_equal(convert(0, 0), (0, 0))
    # A custom retinotopy can be passed to the constructor:
    displaced = AxonMapModel(retinotopy=Watson2014DisplaceMap())
    npt.assert_equal(isinstance(displaced.retinotopy, Watson2014DisplaceMap),
                     True)

    def _assert_dark(imp):
        """Assert that `model` predicts an all-zero percept for `imp`."""
        npt.assert_almost_equal(model.predict_percept(imp).data, 0)

    # Zeros in, zeros out (stim given to the constructor, then reassigned):
    implant = ArgusII(stim=np.zeros(60))
    _assert_dark(implant)
    implant.stim = np.zeros(60)
    _assert_dark(implant)

    # Implant and model must be built for same eye:
    with pytest.raises(ValueError):
        model.predict_percept(ArgusII(eye='LE', stim=np.zeros(60)))
    # An invalid eye is rejected whether given to the constructor or build():
    with pytest.raises(ValueError):
        AxonMapModel(eye='invalid').build()
    with pytest.raises(ValueError):
        AxonMapModel(xystep=5).build(eye='invalid')

    # Lambda cannot be too small:
    with pytest.raises(ValueError):
        AxonMapModel(axlambda=9).build()
# Example #6
def test_AxonMapModel_predict_percept(engine):
    """Check AxonMapModel percepts for single- and all-electrode stimuli.

    Also verifies that the full model reproduces the output of its spatial
    component, ``AxonMapSpatial``, when both use the same parameters.
    """
    model = AxonMapModel(xystep=0.55,
                         axlambda=100,
                         rho=100,
                         thresh_percept=0,
                         engine=engine,
                         xrange=(-20, 20),
                         yrange=(-15, 15),
                         n_axons=500)
    model.build()
    # Single-electrode stim:
    img_stim = np.zeros(60)
    img_stim[47] = 1
    percept = model.predict_percept(ArgusII(stim=img_stim))
    # Single bright pixel, rest of arc is less bright:
    npt.assert_equal(np.sum(percept.data > 0.8), 1)
    npt.assert_equal(np.sum(percept.data > 0.6), 2)
    npt.assert_equal(np.sum(percept.data > 0.1), 7)
    npt.assert_equal(np.sum(percept.data > 0.0001), 32)
    # Overall only a few bright pixels:
    npt.assert_almost_equal(np.sum(percept.data), 3.31321, decimal=3)
    # Brightest pixel is in lower right:
    npt.assert_almost_equal(percept.data[33, 46, 0], np.max(percept.data))
    # Top half is empty:
    npt.assert_almost_equal(np.sum(percept.data[:27, :, 0]), 0)
    # Same for lower band:
    npt.assert_almost_equal(np.sum(percept.data[39:, :, 0]), 0)

    # Full Argus II with small lambda: 60 bright spots
    model = AxonMapModel(engine='serial',
                         xystep=1,
                         rho=100,
                         axlambda=40,
                         xrange=(-20, 20),
                         yrange=(-15, 15),
                         n_axons=500)
    model.build()
    percept = model.predict_percept(ArgusII(stim=np.ones(60)))
    # Most spots are pretty bright, but there are 2 dimmer ones (due to their
    # location on the retina):
    npt.assert_equal(np.sum(percept.data > 0.5), 28)
    npt.assert_equal(np.sum(percept.data > 0.275), 56)

    # Model gives same outcome as Spatial. The spatial model needs the same
    # grid and axon-map settings as ``model`` and must make the prediction
    # itself; otherwise the model would only be compared against itself.
    spatial = AxonMapSpatial(engine='serial',
                             xystep=1,
                             rho=100,
                             axlambda=40,
                             xrange=(-20, 20),
                             yrange=(-15, 15),
                             n_axons=500)
    spatial.build()
    spatial_percept = spatial.predict_percept(ArgusII(stim=np.ones(60)))
    npt.assert_almost_equal(percept.data, spatial_percept.data)
    npt.assert_equal(percept.time, None)
# Example #7
def test_AxonMapSpatial(engine):
    """AxonMapSpatial parameter handling, retinotopy, empty/zero stimuli,
    ``axlambda`` validation, and independent processing of frames.
    """
    # AxonMapSpatial automatically sets `rho`, `axlambda`:
    model = AxonMapSpatial(engine=engine, xystep=5)

    # `rho` can be set directly, and overridden again by build():
    model.rho = 123
    npt.assert_equal(model.rho, 123)
    model.build(rho=987)
    npt.assert_equal(model.rho, 987)

    # Converting ret <=> dva maps the origin onto itself in both directions:
    npt.assert_equal(isinstance(model.retinotopy, Watson2014Map), True)
    for convert in (model.retinotopy.ret2dva, model.retinotopy.dva2ret):
        npt.assert_almost_equal(convert(0, 0), (0, 0))
    displaced = AxonMapSpatial(retinotopy=Watson2014DisplaceMap())
    npt.assert_equal(isinstance(displaced.retinotopy, Watson2014DisplaceMap),
                     True)

    # Nothing in, None out:
    npt.assert_equal(model.predict_percept(ArgusI()), None)

    # Zero in = zero out, as a single-frame Percept on the model's grid:
    percept = model.predict_percept(ArgusI(stim=np.zeros(16)))
    npt.assert_equal(isinstance(percept, Percept), True)
    npt.assert_equal(percept.shape, list(model.grid.x.shape) + [1])
    npt.assert_almost_equal(percept.data, 0)
    npt.assert_equal(percept.time, None)

    # Lambda cannot be too small:
    with pytest.raises(ValueError):
        AxonMapSpatial(axlambda=9).build()

    # Multiple frames are processed independently:
    model = AxonMapSpatial(engine=engine, rho=200, axlambda=100, xystep=5,
                           xrange=(-20, 20), yrange=(-15, 15))
    model.build()
    two_frames = {'A1': [1, 0], 'B3': [0, 2]}
    if engine == 'jax':
        # Axon map jax predict_percept not implemented yet
        with pytest.raises(NotImplementedError):
            model.predict_percept(ArgusII(stim=two_frames))
        return
    percept = model.predict_percept(ArgusI(stim=two_frames))
    npt.assert_equal(percept.shape, list(model.grid.x.shape) + [2])
    pmax = percept.data.max(axis=(0, 1))
    # Frame 0 peaks at (2, 3) and is dark at (3, 4); frame 1 is the reverse:
    npt.assert_almost_equal(percept.data[2, 3, 0], pmax[0])
    npt.assert_almost_equal(percept.data[2, 3, 1], 0)
    npt.assert_almost_equal(percept.data[3, 4, 0], 0)
    npt.assert_almost_equal(percept.data[3, 4, 1], pmax[1])
    npt.assert_almost_equal(percept.time, [0, 1])
def test_plot_argus_phosphenes():
    """plot_argus_phosphenes draws phosphene images onto Argus implants and
    rejects invalid data, implants and axon maps.
    """
    df = pd.DataFrame([
        {
            'subject': 'S1',
            'electrode': electrode,
            'image': np.random.rand(10, 10),
            'xrange': (-10, 10),
            'yrange': (-10, 10)
        }
        for electrode in ('A1', 'B2')
    ])
    _, ax = plt.subplots()
    for implant_cls in (ArgusI, ArgusII):
        plot_argus_phosphenes(df, implant_cls(), ax=ax)

    # Add axon map:
    _, ax = plt.subplots()
    plot_argus_phosphenes(df, ArgusI(), ax=ax, axon_map=AxonMapModel())

    # Data must be a DataFrame:
    with pytest.raises(TypeError):
        plot_argus_phosphenes(np.ones(10), ArgusI())
    # DataFrame must have the required columns:
    with pytest.raises(ValueError):
        plot_argus_phosphenes(pd.DataFrame(), ArgusI())
    # Subjects must all be the same:
    with pytest.raises(ValueError):
        mixed = pd.DataFrame([{'subject': subj} for subj in ('S1', 'S2')])
        plot_argus_phosphenes(mixed, ArgusI())
    # Works only for Argus:
    with pytest.raises(TypeError):
        plot_argus_phosphenes(df, AlphaAMS())
    # Works only for axon maps:
    with pytest.raises(TypeError):
        plot_argus_phosphenes(df, ArgusI(), ax=ax, axon_map=ScoreboardModel())
    # Manual subject selection
    plot_argus_phosphenes(df[df.electrode == 'B2'], ArgusI(), ax=ax)

    # If no implant given, dataframe must have additional columns:
    with pytest.raises(ValueError):
        plot_argus_phosphenes(df, ax=ax)
    implant_columns = (('implant_type_str', 'ArgusII'), ('implant_x', 0),
                       ('implant_y', 0), ('implant_rot', 0))
    for column, fill in implant_columns:
        df[column] = fill
    plot_argus_phosphenes(df, ax=ax)
# Example #9
def test_ScoreboardModel_predict_percept():
    """Check ScoreboardModel percepts for single- and all-electrode stimuli.

    Also verifies that the full model reproduces the output of its spatial
    component, ``ScoreboardSpatial``, when both use the same parameters,
    and that nonzero electrode-retina distances trigger a warning.
    """
    model = ScoreboardModel(xystep=0.55,
                            rho=100,
                            thresh_percept=0,
                            xrange=(-20, 20),
                            yrange=(-15, 15))
    model.build()
    # Single-electrode stim:
    img_stim = np.zeros(60)
    img_stim[47] = 1
    percept = model.predict_percept(ArgusII(stim=img_stim))
    # Single bright pixel, very small Gaussian kernel:
    npt.assert_equal(np.sum(percept.data > 0.8), 1)
    npt.assert_equal(np.sum(percept.data > 0.5), 2)
    npt.assert_equal(np.sum(percept.data > 0.1), 7)
    npt.assert_equal(np.sum(percept.data > 0.00001), 32)
    # Brightest pixel is in lower right:
    npt.assert_almost_equal(percept.data[33, 46, 0], np.max(percept.data))

    # Full Argus II: 60 bright spots
    model = ScoreboardModel(engine='serial', xystep=0.55, rho=100)
    model.build()
    percept = model.predict_percept(ArgusII(stim=np.ones(60)))
    npt.assert_equal(np.sum(np.isclose(percept.data, 0.8, rtol=0.1, atol=0.1)),
                     88)

    # Model gives same outcome as Spatial. The spatial model must use the
    # model's grid (same ``xystep``) and must make the prediction itself;
    # otherwise the model would only be compared against itself.
    spatial = ScoreboardSpatial(engine='serial', xystep=0.55, rho=100)
    spatial.build()
    spatial_percept = spatial.predict_percept(ArgusII(stim=np.ones(60)))
    npt.assert_almost_equal(percept.data, spatial_percept.data)
    npt.assert_equal(percept.time, None)

    # Warning for nonzero electrode-retina distances
    implant = ArgusI(stim=np.ones(16), z=10)
    msg = ("Nonzero electrode-retina distances do not have any effect on the "
           "model output.")
    assert_warns_msg(UserWarning, model.predict_percept, msg, implant)
# Example #10
def test_plot_implant_on_axon_map():
    """plot_implant_on_axon_map returns a Subplot with the expected limits,
    optic disc, annotations and orientation, and validates its inputs.
    """
    ax = plot_implant_on_axon_map(ArgusII())
    npt.assert_equal(isinstance(ax, Subplot), True)

    # Check axis limits; a limit passed as None falls back to the default:
    default_xlim, default_ylim = (-4000, 4500), (-2500, 3000)
    for xlim, ylim in [(None, (-3000, 1300)), ((-2000, 1500), None)]:
        ax = plot_implant_on_axon_map(ArgusII(), xlim=xlim, ylim=ylim)
        npt.assert_almost_equal(ax.get_xlim(),
                                default_xlim if xlim is None else xlim)
        npt.assert_almost_equal(ax.get_ylim(),
                                default_ylim if ylim is None else ylim)

    # Check optic disc center in both eyes (x is mirrored for the LE):
    model = AxonMapSpatial()
    for eye in ['RE', 'LE']:
        for x_od, y_od in [(15.5, 1.5), (17.9, -0.01)]:
            od = (-x_od if eye == 'LE' else x_od, y_od)
            ax = plot_implant_on_axon_map(ArgusII(eye=eye), loc_od=od)
            npt.assert_equal(len(ax.patches), 1)
            npt.assert_almost_equal(ax.patches[0].center, model.dva2ret(od))
            close(ax.figure)

    # Electrodes and quadrants can be annotated; all 60 electrodes are
    # always drawn:
    annotation_cases = [(ann_el, n_el, ann_q, n_q)
                        for ann_el, n_el in [(True, 60), (False, 0)]
                        for ann_q, n_q in [(True, 4), (False, 0)]]
    for ann_el, n_el, ann_q, n_q in annotation_cases:
        ax = plot_implant_on_axon_map(ArgusII(),
                                      annotate_implant=ann_el,
                                      annotate_quadrants=ann_q)
        npt.assert_equal(len(ax.texts), n_el + n_q)
        npt.assert_equal(len(ax.collections[0]._paths), 60)
        close(ax.figure)

    # Stimulating electrodes are marked:
    ax = plot_implant_on_axon_map(ArgusII(stim=np.ones(60)))

    # Setting upside_down flips y axis:
    ax = plot_implant_on_axon_map(ArgusII(), upside_down=True)
    npt.assert_almost_equal(ax.get_xlim(), default_xlim)
    npt.assert_almost_equal(ax.get_ylim(), default_ylim[::-1])

    # Only implants are accepted, and at least one bundle is required:
    with pytest.raises(TypeError):
        plot_implant_on_axon_map(DiskElectrode(0, 0, 0, 100))
    with pytest.raises(ValueError):
        plot_implant_on_axon_map(ArgusII(), n_bundles=0)
# Example #11
def test_plot_argus_phosphenes():
    """plot_argus_phosphenes draws phosphene images onto Argus implants and
    rejects invalid data and non-Argus implants.
    """
    df = pd.DataFrame([
        {
            'subject': 'S1',
            'electrode': electrode,
            'image': np.random.rand(10, 10),
            'img_x_dva': (-10, 10),
            'img_y_dva': (-10, 10)
        }
        for electrode in ('A1', 'B2')
    ])
    _, ax = plt.subplots()
    for implant_cls in (ArgusI, ArgusII):
        plot_argus_phosphenes(df, implant_cls(), ax=ax)

    # Add axon map:
    _, ax = plt.subplots()
    plot_argus_phosphenes(df, ArgusI(), ax=ax, axon_map=AxonMapModel())

    # Data must be a DataFrame:
    with pytest.raises(TypeError):
        plot_argus_phosphenes(np.ones(10), ArgusI())
    # DataFrame must have the required columns:
    with pytest.raises(ValueError):
        plot_argus_phosphenes(pd.DataFrame(), ArgusI())
    # Subjects must all be the same:
    with pytest.raises(ValueError):
        mixed = pd.DataFrame([{'subject': subj} for subj in ('S1', 'S2')])
        plot_argus_phosphenes(mixed, ArgusI())
    # Works only for Argus:
    with pytest.raises(TypeError):
        plot_argus_phosphenes(df, AlphaAMS())
# Example #12
def test_plot_implant_on_axon_map():
    """plot_implant_on_axon_map returns (Figure, Subplot) with the expected
    limits, optic disc, annotations and orientation, and validates inputs.
    """
    fig, ax = plot_implant_on_axon_map(ArgusII())
    npt.assert_equal(isinstance(fig, Figure), True)
    npt.assert_equal(isinstance(ax, Subplot), True)

    # Default axis limits are the retinal equivalents of -20/20 (x) and
    # -15/15 (y) degrees of visual angle:
    xmin, xmax, ymin, ymax = dva2ret([-20, 20, -15, 15])
    npt.assert_equal(ax.get_xlim(), (xmin, xmax))
    npt.assert_equal(ax.get_ylim(), (ymin, ymax))

    # Check optic disc center in both eyes (x is mirrored for the LE):
    for eye in ('RE', 'LE'):
        for x_od, y_od in ((15.5, 1.5), (17.9, -0.01)):
            od = (-x_od if eye == 'LE' else x_od, y_od)
            _, ax = plot_implant_on_axon_map(ArgusII(eye=eye), loc_od=od)
            npt.assert_equal(len(ax.patches), 1)
            npt.assert_almost_equal(ax.patches[0].center, dva2ret(od))

    # Electrodes and quadrants can be annotated:
    annotation_cases = [(ann_el, n_el, ann_q, n_q)
                        for ann_el, n_el in ((True, 60), (False, 0))
                        for ann_q, n_q in ((True, 4), (False, 0))]
    for ann_el, n_el, ann_q, n_q in annotation_cases:
        _, ax = plot_implant_on_axon_map(ArgusII(),
                                         annotate_implant=ann_el,
                                         annotate_quadrants=ann_q)
        npt.assert_equal(len(ax.texts), n_el + n_q)

    # Stimulating electrodes are marked:
    plot_implant_on_axon_map(ArgusII(stim=np.ones(60)))

    # Setting upside_down flips y axis:
    _, ax = plot_implant_on_axon_map(ArgusII(), upside_down=True)
    npt.assert_equal(ax.get_xlim(), (xmin, xmax))
    npt.assert_equal(ax.get_ylim(), (ymax, ymin))

    # Only implants are accepted, and at least one bundle is required:
    with pytest.raises(TypeError):
        plot_implant_on_axon_map(DiskElectrode(0, 0, 0, 100))
    with pytest.raises(ValueError):
        plot_implant_on_axon_map(ArgusII(), n_bundles=0)
#     You need to build a model only once. After that, you can apply any number
#     of stimuli -- or even apply the model to different implants -- without
#     having to rebuild (which takes time).
#
# 2. Assigning a stimulus
# -----------------------
# The second step is to specify a visual prosthesis from the
# :py:mod:`~pulse2percept.implants` module.
#
# In the following, we will create an
# :py:class:`~pulse2percept.implants.ArgusII` implant. By default, the implant
# will be centered over the fovea (at x=0, y=0) and aligned with the horizontal
# meridian (rot=0):

from pulse2percept.implants import ArgusII
# Default placement: centered over the fovea (x=0, y=0), rot=0
implant = ArgusII()

##############################################################################
# The easiest way to assign a stimulus to the implant is to pass a NumPy array
# that specifies the current amplitude to be applied to every electrode in the
# implant.
#
# For example, the following sends 10 microamps to all 60 electrodes of the
# implant:

import numpy as np
# 10 microamps on each of the implant's 60 electrodes:
implant.stim = 10 * np.ones(60)

##############################################################################
# .. note::
#
# Example #14
def test_AxonMapModel_predict_percept(engine):
    """Check AxonMapModel percepts for single- and all-electrode stimuli,
    agreement with AxonMapSpatial, and the electrode-distance warning.
    """
    model = AxonMapModel(xystep=0.55, axlambda=100, rho=100, thresh_percept=0,
                         engine=engine, xrange=(-20, 20), yrange=(-15, 15),
                         n_axons=500)
    model.build()
    # Single-electrode stim (electrode index 47 only):
    single = np.zeros(60)
    single[47] = 1
    if engine == 'jax':
        # Axon map jax predict_percept not implemented yet
        with pytest.raises(NotImplementedError):
            model.predict_percept(ArgusII(stim=single))
        return
    data = model.predict_percept(ArgusII(stim=single)).data
    # Single bright pixel, rest of arc is less bright:
    for thresh, n_above in ((0.8, 1), (0.6, 2), (0.1, 7), (0.0001, 31)):
        npt.assert_equal(np.sum(data > thresh), n_above)
    # Overall only a few bright pixels:
    npt.assert_almost_equal(np.sum(data), 3.3087, decimal=3)
    # Brightest pixel is in lower right:
    npt.assert_almost_equal(data[33, 46, 0], np.max(data))
    # Top half is empty, and so is the lower band:
    npt.assert_almost_equal(np.sum(data[:27, :, 0]), 0)
    npt.assert_almost_equal(np.sum(data[39:, :, 0]), 0)

    # Full Argus II with small lambda: 60 bright spots. The same settings
    # are shared by the full model and the spatial model below.
    shared = dict(engine='serial', xystep=1, rho=100, axlambda=40,
                  xrange=(-20, 20), yrange=(-15, 15), n_axons=500)
    model = AxonMapModel(**shared)
    model.build()
    percept = model.predict_percept(ArgusII(stim=np.ones(60)))
    # Most spots are pretty bright, but there are 2 dimmer ones (due to their
    # location on the retina):
    npt.assert_equal(np.sum(percept.data > 0.5), 28)
    npt.assert_equal(np.sum(percept.data > 0.275), 56)

    # Model gives same outcome as Spatial:
    spatial = AxonMapSpatial(**shared)
    spatial.build()
    spatial_percept = spatial.predict_percept(ArgusII(stim=np.ones(60)))
    npt.assert_almost_equal(percept.data, spatial_percept.data)
    npt.assert_equal(percept.time, None)

    # Warning for nonzero electrode-retina distances
    lifted = ArgusI(stim=np.ones(16), z=10)
    msg = ("Nonzero electrode-retina distances do not have any effect on the "
           "model output.")
    assert_warns_msg(UserWarning, model.predict_percept, msg, lifted)
# Example #15
# radial and axonal current spread, respectively. The parameters ``a0``-``a9`` are
# coefficients for the size, streak, and bright models, which will be discussed
# later in this example. The biphasic axon map model supports both the default
# ``'cython'`` engine and a faster, GPU-enabled ``'jax'`` engine.
#
# The rest of the parameters are shared with
# :py:class:`~pulse2percept.models.AxonMapModel`. For full details on these
# parameters, see the Axon Map Tutorial.
#
#
# Next, build the model to perform expensive, one time calculations,
# and specify a visual prosthesis from the
# :py:mod:`~pulse2percept.implants` module. Models with an axon map are well
# suited for epiretinal implants, such as Argus II.
# Build once (expensive), then reuse the model for any number of stimuli:
model.build()
implant = ArgusII()

##############################################################################
# .. important ::
#
#     You need to build a model only once. After that, you can apply any number
#     of stimuli -- or even apply the model to different implants -- without
#     having to rebuild (which takes time).
#
#     However, if you change model parameters
#     (e.g., by directly setting ``model.a5 = 2``), you will have to
#     call ``model.build()`` again for your changes to take effect.
#
#
# You can visualize the location of the implant and the axon map
# Example #16
#     You need to build a model only once. After that, you can apply any number
#     of stimuli -- or even apply the model to different implants -- without
#     having to rebuild (which takes time).
#
# Assigning a stimulus
# --------------------
# The second step is to specify a visual prosthesis from the
# :py:mod:`~pulse2percept.implants` module.
#
# In the following, we will create an
# :py:class:`~pulse2percept.implants.ArgusII` implant. By default, the implant
# will be centered over the fovea (at x=0, y=0) and aligned with the horizontal
# meridian (rot=0):

from pulse2percept.implants import ArgusII
# Default placement: centered over the fovea (x=0, y=0), aligned with the
# horizontal meridian (rot=0):
implant = ArgusII()

##############################################################################
# You can inspect the location of the implant with respect to the underlying
# nerve fiber bundles using the built-in plot methods:

# Both plots go into the current Axes, so the implant is drawn on top of the
# nerve fiber bundles:
model.plot()
implant.plot()

##############################################################################
# By default, the plots will be added to the current Axes object.
# Alternatively, you can pass ``ax=`` to specify in which Axes to plot.
#
# The easiest way to assign a stimulus to the implant is to pass a NumPy array
# that specifies the current amplitude to be applied to every electrode in the
# implant.
def test_biphasicAxonMapSpatial(engine):
    """BiphasicAxonMapSpatial validates its inputs, reduces to the axon map
    model when all effect models return 1, handles ``t_percept``, and
    reproduces reference values for its default effect models.
    """
    # Lambda cannot be too small:
    with pytest.raises(ValueError):
        BiphasicAxonMapSpatial(axlambda=9).build()

    def _a5_implant():
        """Return an ArgusII with a single biphasic pulse train on A5."""
        imp = ArgusII()
        imp.stim = Stimulus({'A5': BiphasicPulseTrain(20, 1, 0.45)})
        return imp

    model = BiphasicAxonMapModel(engine=engine, xystep=2).build()
    if engine == 'jax':
        # Jax not implemented yet
        with pytest.raises(NotImplementedError):
            model.predict_percept(_a5_implant())
        return

    # Only accepts biphasic pulse trains with no delay dur
    non_biphasic = ArgusI(stim=np.ones(16))
    with pytest.raises(TypeError):
        model.predict_percept(non_biphasic)

    # Nothing in, None out:
    npt.assert_equal(model.predict_percept(ArgusI()), None)

    # Zero in = zero out, as a single-frame Percept on the model's grid:
    percept = model.predict_percept(ArgusI(stim=np.zeros(16)))
    npt.assert_equal(isinstance(percept, Percept), True)
    npt.assert_equal(percept.shape, list(model.grid.x.shape) + [1])
    npt.assert_almost_equal(percept.data, 0)
    npt.assert_equal(percept.time, None)

    def _unit_effect_model():
        """Return a new effect model that ignores its inputs and returns 1."""
        def effect(freq, amp, pdur):
            return 1
        return effect

    # Should be equal to axon map model if effects models return 1
    model = BiphasicAxonMapSpatial(engine=engine, xystep=2)
    model.bright_model = _unit_effect_model()
    model.size_model = _unit_effect_model()
    model.streak_model = _unit_effect_model()
    model.build()
    axon_map = AxonMapSpatial(xystep=2).build()
    implant = _a5_implant()
    percept = model.predict_percept(implant)
    percept_axon = axon_map.predict_percept(implant)
    npt.assert_almost_equal(percept.data[:, :, 0],
                            percept_axon.get_brightest_frame())

    # Effect models must be callable
    model = BiphasicAxonMapSpatial(engine=engine, xystep=2)
    model.bright_model = 1.0
    with pytest.raises(TypeError):
        model.build()

    # Without t_percept there is a single frame; with t_percept only the
    # first frame carries data and the rest are empty:
    model = BiphasicAxonMapSpatial(engine=engine, xystep=2)
    model.build()
    implant = _a5_implant()
    percept = model.predict_percept(implant)
    npt.assert_equal(percept.time is None, True)
    percept = model.predict_percept(implant, t_percept=[0, 1, 2, 5, 10])
    npt.assert_equal(len(percept.time), 5)
    npt.assert_equal(np.any(percept.data[:, :, 0]), True)
    npt.assert_equal(np.any(percept.data[:, :, 1:]), False)

    # Test that default models give expected values
    model = BiphasicAxonMapSpatial(engine=engine, rho=400, axlambda=600,
                                   xystep=1, xrange=(-20, 20),
                                   yrange=(-15, 15))
    model.build()
    implant = ArgusII()
    implant.stim = Stimulus({'A4': BiphasicPulseTrain(20, 1, 1)})
    data = model.predict_percept(implant).data
    expected_counts = ((0.0813, 81), (0.1626, 59), (0.2439, 44),
                       (0.4065, 26), (0.5691, 14))
    for thresh, n_above in expected_counts:
        npt.assert_equal(np.sum(data > thresh), n_above)
def test_biphasicAxonMapModel(engine):
    """Test parameter handling of BiphasicAxonMapModel.

    Covers forwarding of attributes between the model, its spatial model,
    and the size/bright/streak effects models. It also checks that
    FreezeError is raised for unknown attributes, that default values can
    be overridden, that zero stimuli give zero percepts, and that ``eye``
    and ``axlambda`` are validated.
    """
    set_params = {
        'xystep': 2,
        'engine': engine,
        'rho': 432,
        'axlambda': 20,
        'n_axons': 9,
        'n_ax_segments': 50,
        'xrange': (-30, 30),
        'yrange': (-20, 20),
        'loc_od': (5, 6),
        'do_thresholding': False
    }
    model = BiphasicAxonMapModel(engine=engine)
    for param in set_params:
        npt.assert_equal(hasattr(model.spatial, param), True)

    # We can set and get effects model params (a0 ... a9)
    for atr in ['a' + str(i) for i in range(0, 10)]:
        npt.assert_equal(hasattr(model, atr), True)
    model.a0 = 5
    # Should propagate to size and bright model
    # But should not be a member of streak or spatial
    npt.assert_equal(model.spatial.size_model.a0, 5)
    npt.assert_equal(model.spatial.bright_model.a0, 5)
    npt.assert_equal(hasattr(model.spatial.streak_model, 'a0'), False)
    with pytest.raises(AttributeError):
        # __getattribute__ is called directly to check that a0 is really
        # absent from the spatial model itself
        model.spatial.__getattribute__('a0')
    # If the spatial model and an effects model have a parameter with the
    # same name, both need to be changed
    model.rho = 350
    model.axlambda = 450
    model.do_thresholding = True
    npt.assert_equal(model.spatial.size_model.rho, 350)
    npt.assert_equal(model.spatial.streak_model.axlambda, 450)
    npt.assert_equal(model.spatial.bright_model.do_thresholding, True)
    npt.assert_equal(model.rho, 350)
    npt.assert_equal(model.axlambda, 450)
    npt.assert_equal(model.do_thresholding, True)

    # Effect model parameters can be passed even in constructor
    model = BiphasicAxonMapModel(engine=engine, a0=5, rho=432)
    npt.assert_equal(model.a0, 5)
    npt.assert_equal(model.spatial.bright_model.a0, 5)
    npt.assert_equal(model.rho, 432)
    npt.assert_equal(model.spatial.size_model.rho, 432)

    # If parameter is not an effects model param, it can't be set
    with pytest.raises(FreezeError):
        model.invalid_param = 5

    # Custom parameters also propagate to effects models
    model = BiphasicAxonMapModel(engine=engine)

    class TestSizeModel():
        """Callable size model exposing a custom `test_param` attribute."""

        def __init__(self):
            self.test_param = 5

        def __call__(self, freq, amp, pdur):
            return 1

    model.size_model = TestSizeModel()
    model.test_param = 10
    npt.assert_equal(model.spatial.size_model.test_param, 10)
    with pytest.raises(AttributeError):
        model.spatial.__getattribute__('test_param')

    # Values are passed correctly even in another class's __init__
    # This also tests for recursion error in another class's __init__
    class TestInitClassGood():
        """Reads a valid effects parameter from inside __init__."""

        def __init__(self):
            self.model = BiphasicAxonMapModel()
            # This shouldn't raise an error
            self.model.a0

    class TestInitClassBad():
        """Sets an unknown parameter from inside __init__."""

        def __init__(self):
            self.model = BiphasicAxonMapModel()
            # This should raise FreezeError: a10 is outside a0 ... a9
            self.model.a10 = 999

    # If this fails, something is wrong with getattr / setattr logic
    TestInitClassGood()
    with pytest.raises(FreezeError):
        TestInitClassBad()

    # User can override default values
    model = BiphasicAxonMapModel(engine=engine)
    for key, value in set_params.items():
        setattr(model.spatial, key, value)
        npt.assert_equal(getattr(model.spatial, key), value)
    model = BiphasicAxonMapModel(**set_params)
    model.build(**set_params)
    for key, value in set_params.items():
        npt.assert_equal(getattr(model.spatial, key), value)

    # Zeros in, zeros out:
    implant = ArgusII(stim=np.zeros(60))
    npt.assert_almost_equal(model.predict_percept(implant).data, 0)
    implant.stim = np.zeros(60)
    npt.assert_almost_equal(model.predict_percept(implant).data, 0)

    # Implant and model must be built for same eye:
    with pytest.raises(ValueError):
        implant = ArgusII(eye='LE', stim=np.zeros(60))
        model.predict_percept(implant)
    with pytest.raises(ValueError):
        BiphasicAxonMapModel(eye='invalid').build()
    with pytest.raises(ValueError):
        BiphasicAxonMapModel(xystep=5).build(eye='invalid')

    # Lambda cannot be too small:
    with pytest.raises(ValueError):
        BiphasicAxonMapModel(axlambda=9).build()
# Example #19
    with pytest.raises(ValueError):
        plot_argus_phosphenes(pd.DataFrame(), ArgusI())
    # Subjects must all be the same:
    with pytest.raises(ValueError):
        dff = pd.DataFrame([{'subject': 'S1'}, {'subject': 'S2'}])
        plot_argus_phosphenes(dff, ArgusI())
    # Works only for Argus:
    with pytest.raises(TypeError):
        plot_argus_phosphenes(df, AlphaAMS())
    # Works only for axon maps:
    with pytest.raises(TypeError):
        plot_argus_phosphenes(df, ArgusI(), ax=ax, axon_map=ScoreboardModel())

    # Manual subject selection
    plot_argus_phosphenes(df[df.electrode == 'B2'], ArgusI(), ax=ax)


@pytest.mark.parametrize('implant', (ArgusI(), ArgusII()))
def test_plot_argus_simulated_phosphenes(implant):
    """Smoke-test ``plot_argus_simulated_phosphenes`` on Argus I and Argus II.

    Electrodes A1, B2 and C3 each receive a 3-element stimulus, the percept
    is predicted with a default ``ScoreboardModel``, and the result is
    plotted twice: once without an explicit Axes, and once onto a new Axes
    with an ``AxonMapModel`` passed as ``axon_map``.

    All figures are closed at the end (even if plotting raises) so that the
    parametrized runs do not leave open matplotlib figures behind.
    """
    implant.stim = {'A1': [1, 0, 0], 'B2': [0, 1, 0], 'C3': [0, 0, 1]}
    percepts = ScoreboardModel().build().predict_percept(implant)

    try:
        plot_argus_simulated_phosphenes(percepts, implant)

        # Add axon map:
        _, ax = plt.subplots()
        plot_argus_simulated_phosphenes(percepts,
                                        implant,
                                        ax=ax,
                                        axon_map=AxonMapModel())
    finally:
        # Release the figure created by ``plt.subplots`` plus any figure the
        # first (Axes-less) call may have drawn into.
        plt.close('all')
# To demonstrate, we will pass a ``GratingStimulus`` to an
# :py:class:`~pulse2percept.implants.ArgusII` implant and use the
# :py:class:`~pulse2percept.models.AxonMapModel` [Beyeler2019]_ to interpret it:
#
# .. important ::
#
#   Don't forget to build the model before using ``predict_percept``
#

from pulse2percept.implants import ArgusII
from pulse2percept.models import AxonMapModel

# ``build()`` hands back the model itself, so it can be chained directly onto
# the constructor:
model = AxonMapModel().build()

# The grating is handed to the implant at construction time:
implant = ArgusII(stim=GratingStimulus((25, 25), temporal_freq=0.1))

# Simulate the percept and animate it:
percept = model.predict_percept(implant)
percept.play()

#####################################################################################
# As you can see in the above code segment, the stimulus passed to the implant does
# not necessarily have to have the same dimensions as the electrode grid.
# This functionality is built into the implant code: the implant will automatically
# rescale the stimulus to the appropriate size.
# In the case of Argus II, the stimulus would thus be downscaled to a 6x10 image.
#
# Pre-Processing Stimuli
# ----------------------
#
Example #21
0
#     You need to build a model only once. After that, you can apply any number
#     of stimuli -- or even apply the model to different implants -- without
#     having to rebuild (which takes time).
#
# Assigning a stimulus
# --------------------
# The second step is to specify a visual prosthesis from the
# :py:mod:`~pulse2percept.implants` module.
#
# In the following, we will create an
# :py:class:`~pulse2percept.implants.ArgusII` implant. By default, the implant
# will be centered over the fovea (at x=0, y=0) and aligned with the horizontal
# meridian (rot=0):

from pulse2percept.implants import ArgusII
# Defaults spelled out: centered on the fovea, aligned with the meridian.
implant = ArgusII(x=0, y=0, rot=0)

##############################################################################
# The easiest way to assign a stimulus to the implant is to pass a NumPy array
# that specifies the current amplitude to be applied to every electrode in the
# implant.
#
# For example, the following sends 10 microamps to all 60 electrodes of the
# implant:

import numpy as np
# 10 microamps on every one of the 60 electrodes:
implant.stim = np.full(60, 10.0)

##############################################################################
# .. note::
#
Example #22
0
# Show the 'image' entry stored under row label 0 of ``data`` in grayscale.
plt.imshow(data.loc[0, 'image'], cmap='gray')

###############################################################################
# However, we might be more interested in seeing how phosphene shape differs
# for different electrodes.
# For this we can use :py:func:`~pulse2percept.viz.plot_argus_phosphenes` from
# the :py:mod:`~pulse2percept.viz` module.
# In addition to the ``data`` matrix, the function will also want an
# :py:class:`~pulse2percept.implants.ArgusII` object implanted at the correct
# location.
#
# Consulting [Beyeler2019]_ tells us that the prosthesis was roughly implanted
# in the following location:

from pulse2percept.implants import ArgusII
# Approximate implant placement from [Beyeler2019]_.
# NOTE(review): rot=-28.4 looks like degrees; confirm against the ArgusII
# docs of the installed pulse2percept version.
argus = ArgusII(eye='RE', x=-1331, y=-850, rot=-28.4)

###############################################################################
# For now, let's focus on the data from Subject 2:

# Re-fetch the dataset, restricted to subject 'S2'.
data = fetch_beyeler2019(subjects='S2')

###############################################################################
# Passing both ``data`` and ``argus`` to
# :py:func:`~pulse2percept.viz.plot_argus_phosphenes` will then allow the
# function to overlay the phosphene drawings over a schematic of the implant.
# Here, phosphene drawings from different trials are averaged, and aligned with
# the center of the electrode that was used to obtain the drawing:

from pulse2percept.viz import plot_argus_phosphenes
# Overlay the phosphene drawings in ``data`` on a schematic of ``argus``.
plot_argus_phosphenes(data, argus)
Example #23
0
###############################################################################
# However, we might be more interested in seeing how phosphene shape differs
# for different electrodes.
# For this we can use :py:func:`~pulse2percept.viz.plot_argus_phosphenes` from
# the :py:mod:`~pulse2percept.viz` module.
# In addition to the ``data`` matrix, the function will also want an
# :py:class:`~pulse2percept.implants.ArgusII` object implanted at the correct
# location.
#
# Consulting [Beyeler2019]_ tells us that the prosthesis was roughly implanted
# in the following location:

from pulse2percept.implants import ArgusII

# Approximate implant placement from [Beyeler2019]_.
# NOTE(review): rot=-0.495 appears to be radians (about -28.4 degrees);
# confirm the unit expected by the installed pulse2percept version.
argus = ArgusII(eye='RE', x=-1331, y=-850, rot=-0.495)

###############################################################################
# (We also need to specify the dimensions of the screens that the subject used,
# expressed in degrees of visual angle, so that we can scale the phosphene
# drawing appropriately. This should really be part of the Beyeler dataset and
# will be fixed in a future version.
# For now, we add the necessary columns ourselves.)
import pandas as pd

data = fetch_beyeler2019(subjects='S2')
# Every row gets the same (-30, 30) tuple. The column must hold Python
# objects: with ``dtype=float`` pandas tries to cast the list of tuples into a
# 2-D float array and raises ValueError ("Data must be 1-dimensional").
data['img_x_dva'] = pd.Series([(-30, 30)] * len(data),
                              index=data.index)
data['img_y_dva'] = pd.Series([(-22.5, 22.5)] * len(data),
                              index=data.index,
Example #24
0
#
#     However, if you change important model parameters outside the constructor
#     (e.g., by directly setting ``model.axlambda = 100``), you will have to
#     call ``model.build()`` again for your changes to take effect.
#
# Assigning a stimulus
# --------------------
# The second step is to specify a visual prosthesis from the
# :py:mod:`~pulse2percept.implants` module.
#
# In the following, we will create an
# :py:class:`~pulse2percept.implants.ArgusII` implant. By default, the implant
# will be centered over the fovea (at x=0, y=0) and aligned with the horizontal
# meridian (rot=0):

# Defaults spelled out: centered on the fovea, aligned with the meridian.
implant = ArgusII(x=0, y=0, rot=0)

##############################################################################
# You can inspect the location of the implant with respect to the underlying
# nerve fiber bundles using the built-in plot methods:

# Draw the model's nerve fiber bundles, then the implant, onto the current
# Axes so the electrode layout can be compared with the bundles.
model.plot()
implant.plot()

##############################################################################
# By default, the plots will be added to the current Axes object.
# Alternatively, you can pass ``ax=`` to specify in which Axes to plot.
#
# The easiest way to assign a stimulus to the implant is to pass a NumPy array
# that specifies the current amplitude to be applied to every electrode in the
# implant.