예제 #1
0
def test_AxonMapModel(engine):
    """Check AxonMapModel parameter handling, zero input, and eye validation.

    Parameters
    ----------
    engine : str
        Value passed on as the model's ``engine`` parameter.
    """
    set_params = {'xystep': 2, 'engine': engine, 'rho': 432, 'axlambda': 2,
                  'n_axons': 9, 'n_ax_segments': 50,
                  'xrange': (-30, 30), 'yrange': (-20, 20),
                  'loc_od_x': 5, 'loc_od_y': 6}
    model = AxonMapModel()
    # Every parameter under test must exist on the spatial model by default:
    for param in set_params:
        npt.assert_equal(hasattr(model.spatial, param), True)

    # User can override default values, either attribute by attribute...
    for key, value in set_params.items():
        setattr(model.spatial, key, value)
        npt.assert_equal(getattr(model.spatial, key), value)
    # ...or through the constructor and ``build``:
    model = AxonMapModel(**set_params)
    model.build(**set_params)
    for key, value in set_params.items():
        npt.assert_equal(getattr(model.spatial, key), value)

    # Zeros in, zeros out (stimulus given at construction, then reassigned):
    implant = ArgusII(stim=np.zeros(60))
    npt.assert_almost_equal(model.predict_percept(implant).data, 0)
    implant.stim = np.zeros(60)
    npt.assert_almost_equal(model.predict_percept(implant).data, 0)

    # Implant and model must be built for same eye. The implant is created
    # outside the ``raises`` block so that only ``predict_percept`` can
    # satisfy the expected ValueError (a failing constructor would otherwise
    # make this check pass for the wrong reason):
    implant = ArgusII(eye='LE', stim=np.zeros(60))
    with pytest.raises(ValueError):
        model.predict_percept(implant)
    # Invalid eye strings are rejected at build time, whether they are given
    # to the constructor or to ``build``:
    with pytest.raises(ValueError):
        AxonMapModel(eye='invalid').build()
    with pytest.raises(ValueError):
        AxonMapModel(xystep=5).build(eye='invalid')
예제 #2
0
def test_AxonMapModel_find_closest_axon(engine):
    """Check ``find_closest_axon`` on the full grid and for a single point.

    Parameters
    ----------
    engine : str
        Value passed on as the model's ``engine`` parameter.
    """
    model = AxonMapModel(xystep=1, engine=engine, n_axons=5,
                         xrange=(-20, 20), yrange=(-15, 15),
                         axons_range=(-45, 45))
    model.build()

    # Fabricate a one-segment axon right next to every grid location:
    bundles = []
    for x, y in zip(model.spatial.grid.xret.ravel(),
                    model.spatial.grid.yret.ravel()):
        point = np.array([x + 0.001, y - 0.001], dtype=np.float32)
        bundles.append(point.reshape((1, 2)))
    nearest = model.spatial.find_closest_axon(bundles)
    for expected, found in zip(bundles, nearest):
        npt.assert_almost_equal(expected[0, 0], found[0, 0])
        npt.assert_almost_equal(expected[0, 1], found[0, 1])

    # A single query point yields one axon rather than a list of axons:
    first = bundles[0]
    query = {'xret': first[0, 0], 'yret': first[0, 1]}
    nearest = model.spatial.find_closest_axon(bundles, **query)
    npt.assert_almost_equal(nearest, first)

    # The index of the matching axon can be requested as well:
    nearest, nearest_idx = model.spatial.find_closest_axon(
        bundles, return_index=True, **query)
    npt.assert_almost_equal(nearest, first)
    npt.assert_equal(nearest_idx, 0)
예제 #3
0
def test_AxonMapModel(engine):
    """Check AxonMapModel parameters, retinotopy, zero input, and errors.

    Parameters
    ----------
    engine : str
        Value passed on as the model's ``engine`` parameter.
    """
    set_params = {
        'xystep': 2,
        'engine': engine,
        'rho': 432,
        'axlambda': 20,
        'n_axons': 9,
        'n_ax_segments': 50,
        'xrange': (-30, 30),
        'yrange': (-20, 20),
        'loc_od': (5, 6)
    }
    model = AxonMapModel()
    # Every parameter under test must exist on the spatial model by default:
    for param in set_params:
        npt.assert_equal(hasattr(model.spatial, param), True)

    # User can override default values, either attribute by attribute...
    for key, value in set_params.items():
        setattr(model.spatial, key, value)
        npt.assert_equal(getattr(model.spatial, key), value)
    # ...or through the constructor and ``build``:
    model = AxonMapModel(**set_params)
    model.build(**set_params)
    for key, value in set_params.items():
        npt.assert_equal(getattr(model.spatial, key), value)

    # Converting ret <=> dva
    npt.assert_equal(isinstance(model.retinotopy, Watson2014Map), True)
    npt.assert_almost_equal(model.retinotopy.ret2dva(0, 0), (0, 0))
    npt.assert_almost_equal(model.retinotopy.dva2ret(0, 0), (0, 0))
    model2 = AxonMapModel(retinotopy=Watson2014DisplaceMap())
    npt.assert_equal(isinstance(model2.retinotopy, Watson2014DisplaceMap),
                     True)

    # Zeros in, zeros out (stimulus given at construction, then reassigned):
    implant = ArgusII(stim=np.zeros(60))
    npt.assert_almost_equal(model.predict_percept(implant).data, 0)
    implant.stim = np.zeros(60)
    npt.assert_almost_equal(model.predict_percept(implant).data, 0)

    # Implant and model must be built for same eye. The implant is created
    # outside the ``raises`` block so that only ``predict_percept`` can
    # satisfy the expected ValueError (a failing constructor would otherwise
    # make this check pass for the wrong reason):
    implant = ArgusII(eye='LE', stim=np.zeros(60))
    with pytest.raises(ValueError):
        model.predict_percept(implant)
    # Invalid eye strings are rejected at build time, whether they are given
    # to the constructor or to ``build``:
    with pytest.raises(ValueError):
        AxonMapModel(eye='invalid').build()
    with pytest.raises(ValueError):
        AxonMapModel(xystep=5).build(eye='invalid')

    # Lambda cannot be too small:
    with pytest.raises(ValueError):
        AxonMapModel(axlambda=9).build()
예제 #4
0
def test_AxonMapModel_find_closest_axon(engine):
    """Check that ``find_closest_axon`` returns the axon next to each pixel.

    Parameters
    ----------
    engine : str
        Value passed on as the model's ``engine`` parameter.
    """
    model = AxonMapModel(
        xystep=1,
        engine=engine,
        n_axons=5,
        xrange=(-20, 20),
        yrange=(-15, 15),
        axons_range=(-45, 45),
    )
    model.build()

    # Fabricate a one-segment axon right next to every grid location:
    bundles = []
    for x, y in zip(model.spatial.grid.xret.ravel(),
                    model.spatial.grid.yret.ravel()):
        point = np.array([x + 0.001, y - 0.001], dtype=np.float32)
        bundles.append(point.reshape((1, 2)))

    nearest = model.spatial.find_closest_axon(bundles)
    for expected, found in zip(bundles, nearest):
        npt.assert_almost_equal(expected[0, 0], found[0, 0])
        npt.assert_almost_equal(expected[0, 1], found[0, 1])
예제 #5
0
def test_AxonMapModel_predict_percept(engine):
    """Check percepts predicted by AxonMapModel for Argus II stimuli.

    Covers a single active electrode, all 60 electrodes active, and the
    agreement between ``AxonMapModel`` and a stand-alone ``AxonMapSpatial``.

    Parameters
    ----------
    engine : str
        Value passed on as the ``engine`` parameter of the first model.
    """
    model = AxonMapModel(xystep=0.55,
                         axlambda=100,
                         rho=100,
                         thresh_percept=0,
                         engine=engine,
                         xrange=(-20, 20),
                         yrange=(-15, 15),
                         n_axons=500)
    model.build()
    # Single-electrode stim:
    img_stim = np.zeros(60)
    img_stim[47] = 1
    percept = model.predict_percept(ArgusII(stim=img_stim))
    # Single bright pixel, rest of arc is less bright:
    npt.assert_equal(np.sum(percept.data > 0.8), 1)
    npt.assert_equal(np.sum(percept.data > 0.6), 2)
    npt.assert_equal(np.sum(percept.data > 0.1), 7)
    npt.assert_equal(np.sum(percept.data > 0.0001), 32)
    # Overall only a few bright pixels:
    npt.assert_almost_equal(np.sum(percept.data), 3.31321, decimal=3)
    # Brightest pixel is in lower right:
    npt.assert_almost_equal(percept.data[33, 46, 0], np.max(percept.data))
    # Top half is empty:
    npt.assert_almost_equal(np.sum(percept.data[:27, :, 0]), 0)
    # Same for lower band:
    npt.assert_almost_equal(np.sum(percept.data[39:, :, 0]), 0)

    # Full Argus II with small lambda: 60 bright spots. The settings are
    # shared so the stand-alone spatial model below is configured identically.
    serial_params = {'engine': 'serial', 'xystep': 1, 'rho': 100,
                     'axlambda': 40, 'xrange': (-20, 20),
                     'yrange': (-15, 15), 'n_axons': 500}
    model = AxonMapModel(**serial_params)
    model.build()
    percept = model.predict_percept(ArgusII(stim=np.ones(60)))
    # Most spots are pretty bright, but there are 2 dimmer ones (due to their
    # location on the retina):
    npt.assert_equal(np.sum(percept.data > 0.5), 28)
    npt.assert_equal(np.sum(percept.data > 0.275), 56)

    # Model gives same outcome as Spatial. The percept must come from
    # ``spatial`` itself; predicting with ``model`` a second time would only
    # compare the model against itself.
    spatial = AxonMapSpatial(**serial_params)
    spatial.build()
    spatial_percept = spatial.predict_percept(ArgusII(stim=np.ones(60)))
    npt.assert_almost_equal(percept.data, spatial_percept.data)
    npt.assert_equal(percept.time, None)
예제 #6
0
def test_AxonMapModel_calc_axon_contribution(engine):
    """Compare ``calc_axon_contribution`` output with a NumPy re-computation.

    Parameters
    ----------
    engine : str
        Value passed on as the model's ``engine`` parameter.
    """
    model = AxonMapModel(
        xystep=2,
        engine=engine,
        n_axons=10,
        xrange=(-20, 20),
        yrange=(-15, 15),
        axons_range=(-30, 30),
    )
    model.build()
    grid_xy = np.column_stack((model.spatial.grid.xret.ravel(),
                               model.spatial.grid.yret.ravel()))
    bundles = model.spatial.grow_axon_bundles()
    closest = model.spatial.find_closest_axon(bundles)
    contrib = model.spatial.calc_axon_contribution(closest)

    # Check lambda math:
    for segments, origin in zip(contrib, grid_xy):
        # Prepend the grid location (with a zero third column) so the first
        # difference below is taken relative to the pixel itself:
        path = np.insert(segments, 0, list(origin) + [0], axis=0)
        dx = np.diff(path[:, 0], axis=0)
        dy = np.diff(path[:, 1], axis=0)
        d2 = np.cumsum(dx ** 2 + dy ** 2)
        expected = np.exp(-d2 / (2.0 * model.spatial.axlambda ** 2))
        npt.assert_almost_equal(expected, segments[:, 2])
예제 #7
0
def test_AxonMapModel_calc_axon_sensitivity(engine):
    """Compare ``calc_axon_sensitivity`` output with a NumPy re-computation.

    Parameters
    ----------
    engine : str
        Value passed on as the model's ``engine`` parameter. For ``'jax'``
        the expected values are padded to a common length before comparing.
    """
    model = AxonMapModel(xystep=2, engine=engine, n_axons=10,
                         xrange=(-20, 20), yrange=(-15, 15),
                         axons_range=(-30, 30))
    model.build()
    grid_xy = np.column_stack((model.spatial.grid.xret.ravel(),
                               model.spatial.grid.yret.ravel()))
    bundles = model.spatial.grow_axon_bundles()
    closest = model.spatial.find_closest_axon(bundles)
    # Two results are needed: an unpadded one that supplies the cut-off axons,
    # and the one actually under test (padded only for the jax engine).
    unpadded = model.spatial.calc_axon_sensitivity(closest, pad=False)
    use_jax = engine == 'jax'
    tested = model.spatial.calc_axon_sensitivity(closest, pad=use_jax)

    # Check lambda math:
    longest = max(len(segments) for segments in unpadded)
    for segments, origin, actual in zip(unpadded, grid_xy, tested):
        path = np.insert(segments, 0, list(origin) + [0], axis=0)
        dx = np.diff(path[:, 0], axis=0)
        dy = np.diff(path[:, 1], axis=0)
        # Squared cumulative path length along the axon:
        d2 = np.cumsum(np.sqrt(dx**2 + dy**2))**2
        # Keep only segments closer than the squared-distance cut-off derived
        # from ``min_ax_sensitivity``:
        max_d2 = -2.0 * model.axlambda**2 * np.log(model.min_ax_sensitivity)
        keep = d2 < max_d2
        expected = np.exp(-d2[keep] / (2.0 * model.spatial.axlambda**2))
        # Axons need to be padded for jax
        if use_jax:
            n_valid = len(expected)
            padded = np.zeros((longest))
            padded[:n_valid] = expected
            if n_valid > 0:
                # Fill the tail with the last valid sensitivity value:
                padded[n_valid:] = expected[-1]
            expected = padded.astype(np.float32)
        npt.assert_almost_equal(expected, actual[:, 2])
예제 #8
0
#    * 'multiprocessing': a scheduler backed by a process pool
#
# .. _JobLib: https://joblib.readthedocs.io
# .. _Dask: https://dask.org
#
# To change parameter values, either pass them directly to the constructor
# above or set them by hand, like this:

model.engine = 'serial'  # `model` is defined earlier in the script

##############################################################################
# Then build the model. This is a necessary step before you can actually use
# the model to predict a percept, as it performs a number of expensive setup
# computations (e.g., building the axon map, calculating electric potentials):

model.build()  # needed only once; see the note that follows

##############################################################################
# .. note::
#
#     You need to build a model only once. After that, you can apply any number
#     of stimuli -- or even apply the model to different implants -- without
#     having to rebuild (which takes time).
#
# Assigning a stimulus
# --------------------
# The second step is to specify a visual prosthesis from the
# :py:mod:`~pulse2percept.implants` module.
#
# In the following, we will create an
# :py:class:`~pulse2percept.implants.ArgusII` implant. By default, the implant
예제 #9
0
def test_AxonMapModel_predict_percept(engine):
    """Check percepts predicted by AxonMapModel for Argus stimuli.

    Covers a single active electrode, all 60 electrodes active, agreement
    with a stand-alone ``AxonMapSpatial``, and the warning issued for an
    implant created with a nonzero ``z``.

    Parameters
    ----------
    engine : str
        Value passed on as the ``engine`` parameter of the first model. For
        ``'jax'`` only the NotImplementedError is checked.
    """
    model = AxonMapModel(xystep=0.55, axlambda=100, rho=100,
                         thresh_percept=0, engine=engine,
                         xrange=(-20, 20), yrange=(-15, 15), n_axons=500)
    model.build()
    # Single-electrode stim:
    img_stim = np.zeros(60)
    img_stim[47] = 1
    # Axon map jax predict_percept not implemented yet
    if engine == 'jax':
        with pytest.raises(NotImplementedError):
            model.predict_percept(ArgusII(stim=img_stim))
        return
    percept = model.predict_percept(ArgusII(stim=img_stim))
    # Single bright pixel, rest of arc is less bright:
    for threshold, n_pixels in ((0.8, 1), (0.6, 2), (0.1, 7), (0.0001, 31)):
        npt.assert_equal(np.sum(percept.data > threshold), n_pixels)
    # Overall only a few bright pixels:
    npt.assert_almost_equal(np.sum(percept.data), 3.3087, decimal=3)
    # Brightest pixel is in lower right:
    npt.assert_almost_equal(percept.data[33, 46, 0], np.max(percept.data))
    # Top half is empty, and so is the lower band:
    for band in (percept.data[:27, :, 0], percept.data[39:, :, 0]):
        npt.assert_almost_equal(np.sum(band), 0)

    # Full Argus II with small lambda: 60 bright spots
    serial_params = dict(engine='serial', xystep=1, rho=100, axlambda=40,
                         xrange=(-20, 20), yrange=(-15, 15), n_axons=500)
    model = AxonMapModel(**serial_params)
    model.build()
    percept = model.predict_percept(ArgusII(stim=np.ones(60)))
    # Most spots are pretty bright, but there are 2 dimmer ones (due to their
    # location on the retina):
    for threshold, n_pixels in ((0.5, 28), (0.275, 56)):
        npt.assert_equal(np.sum(percept.data > threshold), n_pixels)

    # Model gives same outcome as Spatial:
    spatial = AxonMapSpatial(**serial_params)
    spatial.build()
    spatial_percept = spatial.predict_percept(ArgusII(stim=np.ones(60)))
    npt.assert_almost_equal(percept.data, spatial_percept.data)
    npt.assert_equal(percept.time, None)

    # Warning for nonzero electrode-retina distances
    lifted_implant = ArgusI(stim=np.ones(16), z=10)
    msg = ("Nonzero electrode-retina distances do not have any effect on the "
           "model output.")
    assert_warns_msg(UserWarning, model.predict_percept, msg, lifted_implant)