def test_AxonMapModel(engine):
    """Check AxonMapModel parameter handling, zero stimuli, and eye validation.

    ``engine`` is supplied by the test harness (presumably a pytest fixture
    parametrized over the available compute engines -- confirm in conftest).
    """
    set_params = {'xystep': 2, 'engine': engine, 'rho': 432, 'axlambda': 2,
                  'n_axons': 9, 'n_ax_segments': 50, 'xrange': (-30, 30),
                  'yrange': (-20, 20), 'loc_od_x': 5, 'loc_od_y': 6}
    model = AxonMapModel()
    # Every parameter we are about to set must already exist on the spatial
    # model:
    for param in set_params:
        npt.assert_equal(hasattr(model.spatial, param), True)

    # User can override default values
    for key, value in set_params.items():
        setattr(model.spatial, key, value)
        npt.assert_equal(getattr(model.spatial, key), value)
    model = AxonMapModel(**set_params)
    model.build(**set_params)
    for key, value in set_params.items():
        npt.assert_equal(getattr(model.spatial, key), value)

    # Zeros in, zeros out:
    implant = ArgusII(stim=np.zeros(60))
    npt.assert_almost_equal(model.predict_percept(implant).data, 0)
    implant.stim = np.zeros(60)
    npt.assert_almost_equal(model.predict_percept(implant).data, 0)

    # Implant and model must be built for same eye. The left-eye implant is
    # created *outside* ``pytest.raises`` so that only ``predict_percept``
    # can satisfy the expected ValueError (a failure while constructing the
    # implant must not make this check pass):
    implant = ArgusII(eye='LE', stim=np.zeros(60))
    with pytest.raises(ValueError):
        model.predict_percept(implant)
    # Invalid eye strings are rejected, whether given to the constructor or
    # to build():
    with pytest.raises(ValueError):
        AxonMapModel(eye='invalid').build()
    with pytest.raises(ValueError):
        AxonMapModel(xystep=5).build(eye='invalid')
def test_AxonMapModel_find_closest_axon(engine):
    """``find_closest_axon`` pairs grid points with the fake axons beside them.

    Covers the whole-grid lookup, the single-point lookup through
    ``xret``/``yret``, and the ``return_index=True`` variant.
    """
    model = AxonMapModel(xystep=1, engine=engine, n_axons=5,
                         xrange=(-20, 20), yrange=(-15, 15),
                         axons_range=(-45, 45))
    model.build()
    grid = model.spatial.grid

    # Pretend there is an axon close to each point on the grid: one
    # single-point "axon" per location, nudged by (+0.001, -0.001).
    bundles = []
    for xg, yg in zip(grid.xret.ravel(), grid.yret.ravel()):
        point = np.array([xg + 0.001, yg - 0.001], dtype=np.float32)
        bundles.append(point.reshape((1, 2)))

    closest = model.spatial.find_closest_axon(bundles)
    for fake, match in zip(bundles, closest):
        npt.assert_almost_equal(fake[0, 0], match[0, 0])
        npt.assert_almost_equal(fake[0, 1], match[0, 1])

    # Looking up a single point yields one axon, not a list of axons:
    first = bundles[0]
    x0 = first[0, 0]
    y0 = first[0, 1]
    closest = model.spatial.find_closest_axon(bundles, xret=x0, yret=y0)
    npt.assert_almost_equal(closest, first)

    # ...and it can also report the axon's position in ``bundles``:
    closest, closest_idx = model.spatial.find_closest_axon(
        bundles, xret=x0, yret=y0, return_index=True)
    npt.assert_almost_equal(closest, first)
    npt.assert_equal(closest_idx, 0)
def test_AxonMapModel(engine):
    """Check AxonMapModel parameters, retinotopy, zero stimuli, and validation.

    ``engine`` is supplied by the test harness (presumably a pytest fixture
    parametrized over the available compute engines -- confirm in conftest).
    """
    set_params = {
        'xystep': 2,
        'engine': engine,
        'rho': 432,
        'axlambda': 20,
        'n_axons': 9,
        'n_ax_segments': 50,
        'xrange': (-30, 30),
        'yrange': (-20, 20),
        'loc_od': (5, 6)
    }
    model = AxonMapModel()
    # Every parameter we are about to set must already exist on the spatial
    # model:
    for param in set_params:
        npt.assert_equal(hasattr(model.spatial, param), True)

    # User can override default values
    for key, value in set_params.items():
        setattr(model.spatial, key, value)
        npt.assert_equal(getattr(model.spatial, key), value)
    model = AxonMapModel(**set_params)
    model.build(**set_params)
    for key, value in set_params.items():
        npt.assert_equal(getattr(model.spatial, key), value)

    # Converting ret <=> dva: the default retinotopy is Watson2014Map and
    # maps the origin onto itself in both directions; a custom retinotopy
    # object passed to the constructor is kept.
    npt.assert_equal(isinstance(model.retinotopy, Watson2014Map), True)
    npt.assert_almost_equal(model.retinotopy.ret2dva(0, 0), (0, 0))
    npt.assert_almost_equal(model.retinotopy.dva2ret(0, 0), (0, 0))
    model2 = AxonMapModel(retinotopy=Watson2014DisplaceMap())
    npt.assert_equal(isinstance(model2.retinotopy, Watson2014DisplaceMap),
                     True)

    # Zeros in, zeros out:
    implant = ArgusII(stim=np.zeros(60))
    npt.assert_almost_equal(model.predict_percept(implant).data, 0)
    implant.stim = np.zeros(60)
    npt.assert_almost_equal(model.predict_percept(implant).data, 0)

    # Implant and model must be built for same eye. The left-eye implant is
    # created *outside* ``pytest.raises`` so that only ``predict_percept``
    # can satisfy the expected ValueError (a failure while constructing the
    # implant must not make this check pass):
    implant = ArgusII(eye='LE', stim=np.zeros(60))
    with pytest.raises(ValueError):
        model.predict_percept(implant)
    # Invalid eye strings are rejected, whether given to the constructor or
    # to build():
    with pytest.raises(ValueError):
        AxonMapModel(eye='invalid').build()
    with pytest.raises(ValueError):
        AxonMapModel(xystep=5).build(eye='invalid')

    # Lambda cannot be too small:
    with pytest.raises(ValueError):
        AxonMapModel(axlambda=9).build()
def test_AxonMapModel_find_closest_axon(engine):
    """Each grid point should be matched to the fake axon placed next to it."""
    model = AxonMapModel(xystep=1, engine=engine, n_axons=5,
                         xrange=(-20, 20), yrange=(-15, 15),
                         axons_range=(-45, 45))
    model.build()
    grid = model.spatial.grid

    # Pretend there is an axon close to each point on the grid: one
    # single-point "axon" per location, nudged by (+0.001, -0.001).
    bundles = []
    for xg, yg in zip(grid.xret.ravel(), grid.yret.ravel()):
        point = np.array([xg + 0.001, yg - 0.001], dtype=np.float32)
        bundles.append(point.reshape((1, 2)))

    closest = model.spatial.find_closest_axon(bundles)
    for fake, match in zip(bundles, closest):
        npt.assert_almost_equal(fake[0, 0], match[0, 0])
        npt.assert_almost_equal(fake[0, 1], match[0, 1])
def test_AxonMapModel_predict_percept(engine):
    """Check percepts for single-electrode and full-array Argus II stimuli.

    Also checks that an AxonMapModel and a stand-alone AxonMapSpatial with
    identical settings produce the same percept.
    """
    model = AxonMapModel(xystep=0.55, axlambda=100, rho=100,
                         thresh_percept=0, engine=engine,
                         xrange=(-20, 20), yrange=(-15, 15), n_axons=500)
    model.build()
    # Single-electrode stim:
    img_stim = np.zeros(60)
    img_stim[47] = 1
    percept = model.predict_percept(ArgusII(stim=img_stim))
    # Single bright pixel, rest of arc is less bright:
    npt.assert_equal(np.sum(percept.data > 0.8), 1)
    npt.assert_equal(np.sum(percept.data > 0.6), 2)
    npt.assert_equal(np.sum(percept.data > 0.1), 7)
    npt.assert_equal(np.sum(percept.data > 0.0001), 32)
    # Overall only a few bright pixels:
    npt.assert_almost_equal(np.sum(percept.data), 3.31321, decimal=3)
    # Brightest pixel is in lower right:
    npt.assert_almost_equal(percept.data[33, 46, 0], np.max(percept.data))
    # Top half is empty:
    npt.assert_almost_equal(np.sum(percept.data[:27, :, 0]), 0)
    # Same for lower band:
    npt.assert_almost_equal(np.sum(percept.data[39:, :, 0]), 0)

    # Full Argus II with small lambda: 60 bright spots
    model = AxonMapModel(engine='serial', xystep=1, rho=100, axlambda=40,
                         xrange=(-20, 20), yrange=(-15, 15), n_axons=500)
    model.build()
    percept = model.predict_percept(ArgusII(stim=np.ones(60)))
    # Most spots are pretty bright, but there are 2 dimmer ones (due to their
    # location on the retina):
    npt.assert_equal(np.sum(percept.data > 0.5), 28)
    npt.assert_equal(np.sum(percept.data > 0.275), 56)

    # Model gives same outcome as Spatial. The spatial model must use the
    # same grid and axon settings as ``model``, and *its own*
    # predict_percept must be called -- otherwise the comparison is between
    # ``model`` and itself and can never fail.
    spatial = AxonMapSpatial(engine='serial', xystep=1, rho=100, axlambda=40,
                             xrange=(-20, 20), yrange=(-15, 15), n_axons=500)
    spatial.build()
    spatial_percept = spatial.predict_percept(ArgusII(stim=np.ones(60)))
    npt.assert_almost_equal(percept.data, spatial_percept.data)
    npt.assert_equal(percept.time, None)
def test_AxonMapModel_calc_axon_contribution(engine):
    """Recompute the lambda decay and compare it to ``calc_axon_contribution``.

    For every grid point, the expected decay along its closest axon is
    rebuilt here and compared with column 2 of the model's output.
    """
    model = AxonMapModel(xystep=2, engine=engine, n_axons=10,
                         xrange=(-20, 20), yrange=(-15, 15),
                         axons_range=(-30, 30))
    model.build()
    grid = model.spatial.grid
    xyret = np.column_stack((grid.xret.ravel(), grid.yret.ravel()))
    bundles = model.spatial.grow_axon_bundles()
    axons = model.spatial.find_closest_axon(bundles)
    contrib = model.spatial.calc_axon_contribution(axons)
    two_lambda_sq = 2.0 * model.spatial.axlambda ** 2

    # Check lambda math:
    for segments, (x0, y0) in zip(contrib, xyret):
        # Prepend the grid point itself (third column 0) so that the first
        # difference spans grid point -> first axon segment:
        path = np.insert(segments, 0, [x0, y0, 0], axis=0)
        dx = np.diff(path[:, 0], axis=0)
        dy = np.diff(path[:, 1], axis=0)
        # NOTE(review): this accumulates *squared segment lengths*, not the
        # square of the cumulative path length. Presumably it mirrors the
        # library version under test -- confirm before changing.
        d2 = np.cumsum(dx ** 2 + dy ** 2)
        expected = np.exp(-d2 / two_lambda_sq)
        npt.assert_almost_equal(expected, segments[:, 2])
def test_AxonMapModel_calc_axon_sensitivity(engine):
    """Recompute axon sensitivities and compare them to the model's output.

    Sensitivity decays as exp(-d2 / (2 * axlambda**2)), where d2 is the squared
    cumulative path length from the grid point along its closest axon. Only
    points with d2 below the cutoff derived from ``min_ax_sensitivity`` are
    kept. For the jax engine, the model pads every axon to the longest
    length by repeating the last sensitivity value, so the expected values
    are padded the same way.
    """
    model = AxonMapModel(xystep=2, engine=engine, n_axons=10,
                         xrange=(-20, 20), yrange=(-15, 15),
                         axons_range=(-30, 30))
    model.build()
    xyret = np.column_stack((model.spatial.grid.xret.ravel(),
                             model.spatial.grid.yret.ravel()))
    bundles = model.spatial.grow_axon_bundles()
    axons = model.spatial.find_closest_axon(bundles)
    # Need two separate contribs, one to get cut off axons from, and another
    # to actually test against (with/without padding)
    contrib = model.spatial.calc_axon_sensitivity(axons, pad=False)
    pad = engine == 'jax'
    axon_contrib = model.spatial.calc_axon_sensitivity(axons, pad=pad)

    # Loop invariants, computed once instead of per axon. A single lambda
    # value is used for both the cutoff and the decay:
    axlambda = model.spatial.axlambda
    # Squared distance at which exp(-d2 / (2 axlambda^2)) reaches
    # min_ax_sensitivity; points at or beyond it are cut off:
    max_d2 = -2.0 * axlambda ** 2 * np.log(model.min_ax_sensitivity)
    max_axon_length = max(len(ax) for ax in contrib)

    # Check lambda math:
    for ax, xy, model_ax in zip(contrib, xyret, axon_contrib):
        # Prepend the grid point (third column 0) so the first segment runs
        # from the grid point to the axon:
        axon = np.insert(ax, 0, list(xy) + [0], axis=0)
        # Squared cumulative path length along the axon:
        d2 = np.cumsum(
            np.sqrt(np.diff(axon[:, 0], axis=0) ** 2 +
                    np.diff(axon[:, 1], axis=0) ** 2)) ** 2
        idx_d2 = d2 < max_d2
        sensitivity = np.exp(-d2[idx_d2] / (2.0 * axlambda ** 2))
        # Axons need to be padded for jax: repeat the last value up to the
        # length of the longest axon (all zeros if nothing survived).
        if pad:
            padded = np.zeros(max_axon_length)
            padded[:len(sensitivity)] = sensitivity
            if len(sensitivity) > 0:
                padded[len(sensitivity):] = sensitivity[-1]
            sensitivity = padded.astype(np.float32)
        npt.assert_almost_equal(sensitivity, model_ax[:, 2])
# * 'multiprocessing': a scheduler backed by a process pool # # .. _JobLib: https://joblib.readthedocs.io # .. _Dask: https://dask.org # # To change parameter values, either pass them directly to the constructor # above or set them by hand, like this: model.engine = 'serial' ############################################################################## # Then build the model. This is a necessary step before you can actually use # the model to predict a percept, as it performs a number of expensive setup # computations (e.g., building the axon map, calculating electric potentials): model.build() ############################################################################## # .. note:: # # You need to build a model only once. After that, you can apply any number # of stimuli -- or even apply the model to different implants -- without # having to rebuild (which takes time). # # Assigning a stimulus # -------------------- # The second step is to specify a visual prosthesis from the # :py:mod:`~pulse2percept.implants` module. # # In the following, we will create an # :py:class:`~pulse2percept.implants.ArgusII` implant. By default, the implant
def test_AxonMapModel_predict_percept(engine):
    """Check percepts from the axon map model for Argus II / Argus I stimuli.

    Under the jax engine, ``predict_percept`` is expected to raise
    NotImplementedError and the test stops there. Otherwise it checks:

    * the brightness pattern produced by a single active electrode,
    * the brightness of a full-array stimulus,
    * agreement between AxonMapModel and an identically configured
      AxonMapSpatial,
    * a warning for nonzero electrode-retina distance.
    """
    model = AxonMapModel(xystep=0.55, axlambda=100, rho=100,
                         thresh_percept=0, engine=engine,
                         xrange=(-20, 20), yrange=(-15, 15), n_axons=500)
    model.build()
    # Single-electrode stim:
    img_stim = np.zeros(60)
    img_stim[47] = 1

    if engine == 'jax':
        # Axon map jax predict_percept not implemented yet
        with pytest.raises(NotImplementedError):
            model.predict_percept(ArgusII(stim=img_stim))
        return

    percept = model.predict_percept(ArgusII(stim=img_stim))
    data = percept.data
    # Single bright pixel, rest of arc is less bright
    # (threshold -> expected number of pixels above it):
    for threshold, n_above in ((0.8, 1), (0.6, 2), (0.1, 7), (0.0001, 31)):
        npt.assert_equal(np.sum(data > threshold), n_above)
    # Overall only a few bright pixels:
    npt.assert_almost_equal(np.sum(data), 3.3087, decimal=3)
    # Brightest pixel is in lower right:
    npt.assert_almost_equal(data[33, 46, 0], np.max(data))
    # Top half is empty, and so is the lower band:
    npt.assert_almost_equal(np.sum(data[:27, :, 0]), 0)
    npt.assert_almost_equal(np.sum(data[39:, :, 0]), 0)

    # Full Argus II with small lambda: 60 bright spots. The same settings
    # are reused for the stand-alone spatial model below.
    shared_params = dict(engine='serial', xystep=1, rho=100, axlambda=40,
                         xrange=(-20, 20), yrange=(-15, 15), n_axons=500)
    model = AxonMapModel(**shared_params)
    model.build()
    percept = model.predict_percept(ArgusII(stim=np.ones(60)))
    # Most spots are pretty bright, but there are 2 dimmer ones (due to their
    # location on the retina):
    npt.assert_equal(np.sum(percept.data > 0.5), 28)
    npt.assert_equal(np.sum(percept.data > 0.275), 56)

    # Model gives same outcome as Spatial:
    spatial = AxonMapSpatial(**shared_params)
    spatial.build()
    spatial_percept = spatial.predict_percept(ArgusII(stim=np.ones(60)))
    npt.assert_almost_equal(percept.data, spatial_percept.data)
    npt.assert_equal(percept.time, None)

    # Warning for nonzero electrode-retina distances
    implant = ArgusI(stim=np.ones(16), z=10)
    msg = ("Nonzero electrode-retina distances do not have any effect on the "
           "model output.")
    assert_warns_msg(UserWarning, model.predict_percept, msg, implant)