def test_parse_pulse_trains():
    # Specify pulse trains in a number of different ways and make sure they
    # are all identical after parsing

    # Create some p2p.implants
    argus = implants.ArgusI()
    simple = implants.ElectrodeArray('subretinal', 0, 0, 0, 0)

    # Seeded generator so the test is reproducible from run to run
    rng = np.random.RandomState(42)
    pt_zero = utils.TimeSeries(1, np.zeros(1000))
    pt_nonzero = utils.TimeSeries(1, rng.rand(1000))

    # Test 1
    # ------
    # Specify wrong number of pulse trains
    with pytest.raises(ValueError):
        stimuli.parse_pulse_trains(pt_nonzero, argus)
    with pytest.raises(ValueError):
        stimuli.parse_pulse_trains([pt_nonzero], argus)
    with pytest.raises(ValueError):
        stimuli.parse_pulse_trains([pt_nonzero] * (argus.num_electrodes - 1),
                                   argus)
    with pytest.raises(ValueError):
        stimuli.parse_pulse_trains([pt_nonzero] * 2, simple)

    # Test 2
    # ------
    # Send non-zero pulse train to specific electrode
    el_name = 'B3'
    el_idx = argus.get_index(el_name)

    # Specify a list of 16 pulse trains (one for each electrode)
    pt0_in = [pt_zero] * argus.num_electrodes
    pt0_in[el_idx] = pt_nonzero
    pt0_out = stimuli.parse_pulse_trains(pt0_in, argus)

    # Specify a dict with non-zero pulse trains
    pt1_in = {el_name: pt_nonzero}
    pt1_out = stimuli.parse_pulse_trains(pt1_in, argus)

    # Make sure the two give the same result. Compare lengths first: `zip`
    # silently stops at the shorter sequence and would hide a mismatch.
    assert len(pt0_out) == argus.num_electrodes
    assert len(pt1_out) == len(pt0_out)
    for p0, p1 in zip(pt0_out, pt1_out):
        npt.assert_equal(p0.data, p1.data)

    # Test 3
    # ------
    # Smoke testing
    stimuli.parse_pulse_trains([pt_zero] * argus.num_electrodes, argus)
    stimuli.parse_pulse_trains(pt_zero, simple)
    stimuli.parse_pulse_trains([pt_zero], simple)
def pulse2percept(self, stim, t_percept=None, tol=0.05,
                  layers=('OFL', 'GCL', 'INL')):
    """Transforms an input stimulus to a percept

    Parameters
    ----------
    stim : utils.TimeSeries|list|dict
        There are several ways to specify an input stimulus:

        - For a single-electrode array, pass a single pulse train; i.e., a
          single utils.TimeSeries object.
        - For a multi-electrode array, pass a list of pulse trains; i.e.,
          one pulse train per electrode.
        - For a multi-electrode array, specify all electrodes that should
          receive non-zero pulse trains by name, as a dict mapping
          electrode name to pulse train.
    t_percept : float, optional, default: inherit from `stim` object
        The desired time sampling of the output (seconds). If None, no
        resampling is performed.
    tol : float, optional, default: 0.05
        Ignore pixels whose effective current is smaller than a fraction
        `tol` of the max value.
    layers : list|tuple|str, optional, default: ('OFL', 'GCL', 'INL')
        A list of retina layers to simulate (order and case do not matter):

        - 'OFL': Includes the optic fiber layer in the simulation.
                 If omitted, the tissue activation map will not account
                 for axon streaks.
        - 'GCL': Includes the ganglion cell layer in the simulation.
        - 'INL': Includes the inner nuclear layer in the simulation.
                 If omitted, bipolar cell activity does not contribute to
                 ganglion cell activity.

    Returns
    -------
    A utils.TimeSeries object whose data container comprises the predicted
    brightness over time at each retinal location, indexed as [y, x, t]
    (the last dimension of the container represents time).

    Raises
    ------
    ValueError
        If no layer is given, if an unsupported layer is given, if the pulse
        trains' sampling step differs from that of the ganglion cell layer,
        or if no pixel passes the `tol` threshold.

    Examples
    --------
    Simulate a single-electrode array:

    >>> import pulse2percept as p2p
    >>> implant = p2p.implants.ElectrodeArray('subretinal', 0, 0, 0)
    >>> stim = p2p.stimuli.PulseTrain(tsample=5e-6, freq=50, amp=20)
    >>> sim = p2p.Simulation(implant)
    >>> percept = sim.pulse2percept(stim)  # doctest: +SKIP

    Simulate an Argus I array centered on the fovea, where a single
    electrode is being stimulated ('C3'):

    >>> import pulse2percept as p2p
    >>> implant = p2p.implants.ArgusI()
    >>> stim = {'C3': p2p.stimuli.PulseTrain(tsample=5e-6, freq=50,
    ...                                      amp=20)}
    >>> sim = p2p.Simulation(implant)
    >>> resp = sim.pulse2percept(stim)  # doctest: +SKIP
    """
    logger = logging.getLogger(__name__)
    logger.info("Starting pulse2percept...")

    # Get a flattened, all-uppercase list of layers (wrapping in a list also
    # lets a single layer name be passed as a plain string)
    layers = np.array([layers]).flatten()
    layers = np.array([name.upper() for name in layers])
    if layers.size == 0:
        # Without any layer no pixel would ever be selected below
        raise ValueError("At least one retina layer must be specified.")

    # Make sure all specified layers exist
    not_supported = np.array(
        [name not in retina.SUPPORTED_LAYERS for name in layers], dtype=bool)
    if np.any(not_supported):
        msg = ', '.join(layers[not_supported])
        msg = "Specified layer %s not supported. " % msg
        msg += "Choose from %s." % ', '.join(retina.SUPPORTED_LAYERS)
        raise ValueError(msg)

    # Slice 0 of the current spread maps is used when the INL is simulated,
    # slice 1 when the GCL and/or the OFL are simulated.
    use_inl = 'INL' in layers
    use_gcl = 'GCL' in layers or 'OFL' in layers

    # Set up all layers that haven't been set up yet
    self._set_layers()

    # Parse `stim` (either single pulse train or a list/dict of pulse
    # trains), and generate a list of pulse trains, one for each electrode
    pt_list = stimuli.parse_pulse_trains(stim, self.implant)
    pt_data = [pt.data for pt in pt_list]

    if not np.allclose([p.tsample for p in pt_list], self.gcl.tsample):
        e_s = "For now, all pulse trains must have the same sampling "
        e_s += "time step as the ganglion cell layer. In the future, "
        e_s += "this requirement might be relaxed."
        raise ValueError(e_s)

    # Tissue activation maps: If OFL is simulated, includes axon streaks.
    if 'OFL' in layers:
        ecs, _ = self.ofl.electrode_ecs(self.implant)
    else:
        _, ecs = self.ofl.electrode_ecs(self.implant)

    # Calculate the max of every current spread map (per electrode, over
    # all pixels)
    lmax = np.zeros((2, ecs.shape[-1]))
    if use_inl:
        lmax[0, :] = ecs[:, :, 0, :].max(axis=(0, 1))
    if use_gcl:
        lmax[1, :] = ecs[:, :, 1, :].max(axis=(0, 1))

    # Per-electrode thresholds are the same for every pixel: compute once
    thresh_inl = tol * lmax[0, :]
    thresh_gcl = tol * lmax[1, :]

    # `ecs_list` is a pixel by `n` list where `n` is the number of layers
    # being simulated. Each value in `ecs_list` is the current contributed
    # by each electrode for that spatial location
    ecs_list = []
    idx_list = []
    for xx in range(self.ofl.gridx.shape[1]):
        for yy in range(self.ofl.gridx.shape[0]):
            # If any of the used current spread maps at [yy, xx] are above
            # tolerance, we need to process that pixel
            process_pixel = False
            if use_inl:
                # For this pixel: Check if the ecs in any layer is large
                # enough compared to the max across pixels within the layer
                process_pixel |= np.any(ecs[yy, xx, 0, :] >= thresh_inl)
            if use_gcl:
                process_pixel |= np.any(ecs[yy, xx, 1, :] >= thresh_gcl)

            if process_pixel:
                ecs_list.append(ecs[yy, xx])
                idx_list.append([yy, xx])

    s_info = "tol=%.1f%%, %d/%d px selected" % (tol * 100, len(ecs_list),
                                                 np.prod(ecs.shape[:2]))
    logger.info(s_info)

    if not ecs_list:
        # Otherwise `sr_list[0]` below fails with an opaque IndexError
        raise ValueError("No pixels passed the current spread threshold "
                         "(tol=%g); nothing to simulate." % tol)

    sr_list = utils.parfor(self.gcl.model_cascade, ecs_list,
                           n_jobs=self.n_jobs, engine=self.engine,
                           scheduler=self.scheduler,
                           func_args=[pt_data, layers, self.use_jit])

    # Scatter the per-pixel responses back onto the [y, x, t] grid; pixels
    # that were skipped stay at zero
    bm = np.zeros(self.ofl.gridx.shape + (sr_list[0].data.shape[-1], ))
    idxer = tuple(np.array(idx_list)[:, i] for i in range(2))
    bm[idxer] = [sr.data for sr in sr_list]
    percept = utils.TimeSeries(sr_list[0].tsample, bm)

    # It is possible to specify an additional sampling rate for the
    # percept: If different from the input sampling rate, need to resample.
    # `t_percept=None` means "keep the current sampling step".
    if t_percept is not None and t_percept != percept.tsample:
        percept = percept.resample(t_percept)

    logger.info("Done.")
    return percept