Exemplo n.º 1
0
def test_parse_pulse_trains():
    """Check that equivalent stimulus specifications parse identically.

    Exercises `stimuli.parse_pulse_trains` with a multi-electrode implant
    (`implants.ArgusI`) and a single-electrode `implants.ElectrodeArray`:

    1. A wrong number of pulse trains must raise `ValueError`.
    2. A full list of pulse trains and a dict naming only the non-zero
       electrode must produce the same output.
    3. A few valid inputs must parse without raising (smoke test).
    """
    # Create some p2p.implants
    argus = implants.ArgusI()
    simple = implants.ElectrodeArray('subretinal', 0, 0, 0, 0)

    pt_zero = utils.TimeSeries(1, np.zeros(1000))
    pt_nonzero = utils.TimeSeries(1, np.random.rand(1000))

    # Test 1
    # ------
    # Specify wrong number of pulse trains
    wrong_counts = [
        (pt_nonzero, argus),
        ([pt_nonzero], argus),
        ([pt_nonzero] * (argus.num_electrodes - 1), argus),
        ([pt_nonzero] * 2, simple),
    ]
    for bad_stim, implant in wrong_counts:
        with pytest.raises(ValueError):
            stimuli.parse_pulse_trains(bad_stim, implant)

    # Test 2
    # ------
    # Send non-zero pulse train to specific electrode
    el_name = 'B3'
    el_idx = argus.get_index(el_name)

    # Specify a list of pulse trains (one for each electrode)
    pt0_in = [pt_zero] * argus.num_electrodes
    pt0_in[el_idx] = pt_nonzero
    pt0_out = stimuli.parse_pulse_trains(pt0_in, argus)

    # Specify a dict with non-zero pulse trains
    pt1_in = {el_name: pt_nonzero}
    pt1_out = stimuli.parse_pulse_trains(pt1_in, argus)

    # Make sure the two give the same result. `zip` stops at the shorter
    # input, so compare lengths first; otherwise a truncated (or empty)
    # result would make the element-wise check pass vacuously.
    npt.assert_equal(len(pt0_out), len(pt1_out))
    for p0, p1 in zip(pt0_out, pt1_out):
        npt.assert_equal(p0.data, p1.data)

    # Test 3
    # ------
    # Smoke testing
    stimuli.parse_pulse_trains([pt_zero] * argus.num_electrodes, argus)
    stimuli.parse_pulse_trains(pt_zero, simple)
    stimuli.parse_pulse_trains([pt_zero], simple)
Exemplo n.º 2
0
    def pulse2percept(self,
                      stim,
                      t_percept=None,
                      tol=0.05,
                      layers=['OFL', 'GCL', 'INL']):
        """Transforms an input stimulus to a percept

        Parameters
        ----------
        stim : utils.TimeSeries|list|dict
            There are several ways to specify an input stimulus:

            - For a single-electrode array, pass a single pulse train; i.e.,
              a single utils.TimeSeries object.
            - For a multi-electrode array, pass a list of pulse trains; i.e.,
              one pulse train per electrode.
            - For a multi-electrode array, specify all electrodes that should
              receive non-zero pulse trains by name (a dict mapping electrode
              name to pulse train).
        t_percept : float, optional, default: inherit from `stim` object
            The desired time sampling of the output (seconds). If None, the
            percept is returned at the time step produced by the model
            cascade, without resampling.
        tol : float, optional, default: 0.05
            Ignore pixels whose effective current is smaller than a fraction
            `tol` of the max value.
        layers : list, optional, default: ['OFL', 'GCL', 'INL']
            A list of retina layers to simulate (order does not matter,
            names are case-insensitive):
            - 'OFL': Includes the optic fiber layer in the simulation.
                     If omitted, the tissue activation map will not account
                     for axon streaks.
            - 'GCL': Includes the ganglion cell layer in the simulation.
            - 'INL': Includes the inner nuclear layer in the simulation.
                     If omitted, bipolar cell activity does not contribute
                     to ganglion cell activity.

        Returns
        -------
        A utils.TimeSeries object whose data container comprises the predicted
        brightness over time at each retinal location (x, y), with the last
        dimension of the container representing time (t).

        Raises
        ------
        ValueError
            If a name in `layers` is not in `retina.SUPPORTED_LAYERS`, or if
            the pulse trains do not all share the sampling time step of the
            ganglion cell layer (`self.gcl.tsample`).

        Examples
        --------
        Simulate a single-electrode array:

        >>> import pulse2percept as p2p
        >>> implant = p2p.implants.ElectrodeArray('subretinal', 0, 0, 0)
        >>> stim = p2p.stimuli.PulseTrain(tsample=5e-6, freq=50, amp=20)
        >>> sim = p2p.Simulation(implant)
        >>> percept = sim.pulse2percept(stim)  # doctest: +SKIP

        Simulate an Argus I array centered on the fovea, where a single
        electrode is being stimulated ('C3'):

        >>> import pulse2percept as p2p
        >>> implant = p2p.implants.ArgusI()
        >>> stim = {'C3': p2p.stimuli.PulseTrain(tsample=5e-6, freq=50,
        ...                                      amp=20)}
        >>> sim = p2p.Simulation(implant)
        >>> resp = sim.pulse2percept(stim)  # doctest: +SKIP
        """
        logger = logging.getLogger(__name__)
        logger.info("Starting pulse2percept...")

        # Get a flattened, all-uppercase list of layers
        layers = np.array([layers]).flatten()
        layers = np.array([layer.upper() for layer in layers])

        # Make sure all specified layers exist
        not_supported = np.array(
            [layer not in retina.SUPPORTED_LAYERS for layer in layers],
            dtype=bool)
        if any(not_supported):
            msg = ', '.join(layers[not_supported])
            msg = "Specified layer %s not supported. " % msg
            msg += "Choose from %s." % ', '.join(retina.SUPPORTED_LAYERS)
            raise ValueError(msg)

        # Which current spread maps take part in the simulation. Index 0 of
        # the third `ecs` axis is used for the INL, index 1 for GCL/OFL.
        # NOTE: this must be spelled out as two membership tests;
        # `('GCL' or 'OFL') in layers` evaluates to `'GCL' in layers` and
        # silently ignores an OFL-only request.
        use_ofl = 'OFL' in layers
        use_inl = 'INL' in layers
        use_gcl = 'GCL' in layers or use_ofl

        # Set up all layers that haven't been set up yet
        self._set_layers()

        # Parse `stim` (either single pulse train or a list/dict of pulse
        # trains), and generate a list of pulse trains, one for each electrode
        pt_list = stimuli.parse_pulse_trains(stim, self.implant)
        pt_data = [pt.data for pt in pt_list]

        if not np.allclose([p.tsample for p in pt_list], self.gcl.tsample):
            e_s = "For now, all pulse trains must have the same sampling "
            e_s += "time step as the ganglion cell layer. In the future, "
            e_s += "this requirement might be relaxed."
            raise ValueError(e_s)

        # Tissue activation maps: If OFL is simulated, includes axon streaks.
        if use_ofl:
            ecs, _ = self.ofl.electrode_ecs(self.implant)
        else:
            _, ecs = self.ofl.electrode_ecs(self.implant)

        # Calculate the max of every current spread map
        lmax = np.zeros((2, ecs.shape[-1]))
        if use_inl:
            lmax[0, :] = ecs[:, :, 0, :].max(axis=(0, 1))
        if use_gcl:
            lmax[1, :] = ecs[:, :, 1, :].max(axis=(0, 1))

        # Per-electrode cutoff below which a pixel's current is ignored;
        # loop-invariant, so compute it once.
        thresh = tol * lmax

        # `ecs_list` is a pixel by `n` list where `n` is the number of layers
        # being simulated. Each value in `ecs_list` is the current contributed
        # by each electrode for that spatial location
        ecs_list = []
        idx_list = []
        grid_shape = self.ofl.gridx.shape
        for xx in range(grid_shape[1]):
            for yy in range(grid_shape[0]):
                px = ecs[yy, xx]
                # If any of the used current spread maps at [yy, xx] are above
                # tolerance (relative to the max across pixels within that
                # map), we need to process that pixel
                process_pixel = (
                    (use_inl and np.any(px[0, :] >= thresh[0, :])) or
                    (use_gcl and np.any(px[1, :] >= thresh[1, :])))

                if process_pixel:
                    ecs_list.append(px)
                    idx_list.append([yy, xx])

        logger.info("tol=%.1f%%, %d/%d px selected", tol * 100,
                    len(ecs_list), np.prod(ecs.shape[:2]))

        # Run the model cascade on every selected pixel
        sr_list = utils.parfor(self.gcl.model_cascade,
                               ecs_list,
                               n_jobs=self.n_jobs,
                               engine=self.engine,
                               scheduler=self.scheduler,
                               func_args=[pt_data, layers, self.use_jit])

        # Scatter the per-pixel responses back onto the full grid; pixels
        # that were skipped stay at zero.
        bm = np.zeros(grid_shape + (sr_list[0].data.shape[-1], ))
        idxer = tuple(np.array(idx_list)[:, i] for i in range(2))
        bm[idxer] = [sr.data for sr in sr_list]
        percept = utils.TimeSeries(sr_list[0].tsample, bm)

        # It is possible to specify an additional sampling rate for the
        # percept: If different from the input sampling rate, need to resample.
        # `t_percept=None` means "keep the current time step", so skip it.
        if t_percept is not None and t_percept != percept.tsample:
            percept = percept.resample(t_percept)

        logger.info("Done.")

        return percept