예제 #1
0
def test_parfor():
    """Check `utils.parfor` against direct calls of `power_it`.

    With `out_shape`, the results are reshaped so that element [i, j]
    corresponds to `my_array[i, j]`; without it, item 0 corresponds to
    `my_array[0, 0]`.
    """
    my_array = np.arange(100).reshape(10, 10)
    # `high` is exclusive in np.random.randint: draw from the full shape so
    # the last row/column (index 9) can be selected as well.
    i = np.random.randint(0, my_array.shape[0])
    j = np.random.randint(0, my_array.shape[1])
    my_list = list(my_array.ravel())

    reshaped = utils.parfor(power_it, my_list, out_shape=my_array.shape)
    npt.assert_equal(reshaped[i, j], power_it(my_array[i, j]))
    # Always check the last element too, independent of the random draw.
    npt.assert_equal(reshaped[-1, -1], power_it(my_array[-1, -1]))

    # If it's not reshaped, the first item should be the item 0, 0:
    npt.assert_equal(utils.parfor(power_it, my_list)[0],
                     power_it(my_array[0, 0]))
예제 #2
0
def plot_fundus(ax, subject, subjectdata, n_bundles=100, upside_down=False,
                annot_array=True, annot_quadr=True):
    """Plot an implant on top of the axon map

    This function plots an electrode array on top of the axon map, akin to a
    fundus photograph. Implant location should be given via `subjectdata`.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or None
        A Matplotlib axes object to draw into. If None, a new figure and axes
        (figsize=(10, 8)) are created. This argument is positional and has no
        default value, so it must always be passed (possibly as None).
    subject : str
        Subject ID; must be a value in the index of `subjectdata`.
    subjectdata : pd.DataFrame
        DataFrame with Subject ID as index. Must have columns 'implant_x',
        'implant_y', 'implant_rot', 'implant_type', 'eye', 'loc_od_x',
        'loc_od_y'. The 'implant_type' entry is called with the keyword
        arguments `x_center`, `y_center`, `rot` and `eye` to build the
        implant (presumably an implant class -- confirm against callers).
    n_bundles : int, optional, default: 100
        Number of nerve fiber bundles to plot.
    upside_down : bool, optional, default: False
        Flag whether to plot the retina upside-down, such that the upper
        half of the plot corresponds to the upper visual field. In general,
        inferior retina == upper visual field (and superior == lower).
    annot_array : bool, optional, default: True
        Flag whether to label electrodes and the tack.
    annot_quadr : bool, optional, default: True
        Flag whether to annotate the four retinal quadrants
        (inferior/superior x temporal/nasal).

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure that contains `ax`.
    ax : matplotlib.axes.Axes
        The axes the fundus was drawn into.

    Raises
    ------
    ValueError
        If `subjectdata` lacks a required column, `subject` is not in the
        index of `subjectdata`, or `n_bundles` < 1.
    """
    for col in ['implant_x', 'implant_y', 'implant_rot', 'implant_type',
                'eye', 'loc_od_x', 'loc_od_y']:
        if col not in subjectdata.columns:
            raise ValueError('subjectdata must contain column "%s".' % col)
    if subject not in subjectdata.index:
        raise ValueError('Subject "%s" not an index in subjectdata.' % subject)
    if n_bundles < 1:
        raise ValueError('Number of nerve fiber bundles must be >= 1.')

    # Choose the appropriate image / electrode locations based on implant type:
    implant_type = subjectdata.loc[subject, 'implant_type']
    implant = implant_type(x_center=subjectdata.loc[subject, 'implant_x'],
                           y_center=subjectdata.loc[subject, 'implant_y'],
                           rot=subjectdata.loc[subject, 'implant_rot'],
                           eye=subjectdata.loc[subject, 'eye'])
    loc_od = tuple(subjectdata.loc[subject, ['loc_od_x', 'loc_od_y']])

    phi_range = (-180.0, 180.0)
    n_rho = 801
    rho_range = (2.0, 45.0)

    # Make sure x-coord of optic disc has the correct sign for LE/RE:
    # the right eye expects x > 0, the left eye x <= 0.
    if (implant.eye == 'RE' and loc_od[0] <= 0 or
            implant.eye == 'LE' and loc_od[0] > 0):
        logstr = ("For eye==%s, expected opposite sign of x-coordinate of "
                  "the optic disc; changing %.2f to %.2f" % (implant.eye,
                                                             loc_od[0],
                                                             -loc_od[0]))
        print(logstr)
        loc_od = (-loc_od[0], loc_od[1])
    if ax is None:
        # No axes object given: create
        fig, ax = plt.subplots(1, figsize=(10, 8))
    else:
        fig = ax.figure

    # Matplotlib<2 compatibility
    if hasattr(ax, 'set_facecolor'):
        ax.set_facecolor('black')
    elif hasattr(ax, 'set_axis_bgcolor'):
        ax.set_axis_bgcolor('black')

    # Draw axon pathways:
    phi = np.linspace(*phi_range, num=n_bundles)
    func_kwargs = {'n_rho': n_rho, 'loc_od': loc_od,
                   'rho_range': rho_range, 'eye': implant.eye}
    axon_bundles = p2pu.parfor(p2pr.jansonius2009, phi,
                               func_kwargs=func_kwargs)
    for bundle in axon_bundles:
        ax.plot(p2pr.dva2ret(bundle[:, 0]), p2pr.dva2ret(bundle[:, 1]),
                c=(0.5, 1.0, 0.5))

    # Plot all electrodes and label them (optional):
    for e in implant.electrodes:
        if annot_array:
            ax.text(e.x_center + 100, e.y_center + 50, e.name,
                    color='white', size='x-large')
        ax.plot(e.x_center, e.y_center, 'ow', markersize=np.sqrt(e.radius))

    # Plot the location of the array's tack and annotate it (optional):
    if implant.tack:
        tx, ty = implant.tack
        ax.plot(tx, ty, 'ow')
        if annot_array:
            # Place the label below the tack on screen; the y axis is
            # inverted further down when `upside_down` is set.
            if upside_down:
                offset = 100
            else:
                offset = -100
            ax.text(tx, ty + offset, 'tack',
                    horizontalalignment='center',
                    verticalalignment='top',
                    color='white', size='large')

    # Show circular optic disc:
    ax.add_patch(patches.Circle(p2pr.dva2ret(loc_od), radius=900, alpha=1,
                                color='black', zorder=10))

    xmin, xmax, ymin, ymax = p2pr.dva2ret([-20, 20, -15, 15])
    ax.set_aspect('equal')
    ax.set_xlim(xmin, xmax)
    ax.set_xlabel('x (microns)')
    ax.set_ylim(ymin, ymax)
    ax.set_ylabel('y (microns)')
    eyestr = {'LE': 'left', 'RE': 'right'}
    ax.set_title('%s in %s eye' % (implant, eyestr[implant.eye]))
    # Use a bool rather than the string 'off': newer Matplotlib releases no
    # longer map 'on'/'off' to booleans, and a non-empty string is truthy.
    ax.grid(False)

    # Annotate the four retinal quadrants near the corners of the plot:
    # superior/inferior x temporal/nasal
    if annot_quadr:
        if upside_down:
            topbottom = ['bottom', 'top']
        else:
            topbottom = ['top', 'bottom']
        if implant.eye == 'RE':
            temporalnasal = ['temporal', 'nasal']
        else:
            temporalnasal = ['nasal', 'temporal']
        for yy, valign, si in zip([ymax, ymin], topbottom,
                                  ['superior', 'inferior']):
            for xx, halign, tn in zip([xmin, xmax], ['left', 'right'],
                                      temporalnasal):
                ax.text(xx, yy, si + ' ' + tn,
                        color='black', fontsize=14,
                        horizontalalignment=halign,
                        verticalalignment=valign,
                        backgroundcolor=(1, 1, 1, 0.8))

    # Need to flip y axis to have upper half == upper visual field
    if upside_down:
        ax.invert_yaxis()

    return fig, ax
def pulse2percept(stim, implant, tm=None, retina=None,
                  rsample=30, scale_charge=42.1, tol=0.05, use_ecs=True,
                  engine='joblib', dojit=True, n_jobs=-1, verbose=True):
    """Transforms an input stimulus to a percept

    This function passes an input stimulus `stim` to a retinal `implant`,
    which is placed on a simulated `retina`, and produces a predicted percept
    by means of the temporal model `tm`.

    Parameters
    ----------
    stim : utils.TimeSeries|list|dict
        There are several ways to specify an input stimulus:
        - For a single-electrode array, pass a single pulse train; i.e., a
          single utils.TimeSeries object.
        - For a multi-electrode array, pass a list of pulse trains; i.e., one
          pulse train per electrode.
        - For a multi-electrode array, specify all electrodes that should
          receive non-zero pulse trains by name.
        The pulse trains passed in are not modified.
    implant : e2cm.ElectrodeArray
        An ElectrodeArray object that describes the implant.
    tm : ec2b.TemporalModel
        A model of temporal sensitivity. If None, a default TemporalModel is
        built using the sampling step of the first pulse train.
    retina : e2cm.Retina
        A Retina object specifying the geometry of the retina. If None, a
        retina is generated that covers all electrodes plus a 500 micron
        margin, rounded outward to multiples of 500 microns.
    rsample : int
        Resampling factor, passed on to the per-pixel computation. For
        example, a resampling factor of 3 keeps only every third frame.
        Default: 30.
    scale_charge : float
        Scaling factor applied to charge accumulation (used to be called
        epsilon). Default: 42.1.
    tol : float
        Ignore pixels whose effective current is smaller than tol.
        Default: 0.05.
    use_ecs : bool
        Flag whether to use effective current spread (True) or regular
        current spread, unaffected by axon pathways (False).
        Default: True.
    engine : str
        Which computational backend to use:
        - 'serial': Single-core computation
        - 'joblib': Parallelization via joblib (requires `pip install joblib`)
        - 'dask': Parallelization via dask (requires `pip install dask`)
        Default: joblib.
    dojit : bool
        Whether to use just-in-time (JIT) compilation to speed up computation.
        Default: True.
    n_jobs : int
        Number of cores (threads) to run the model on in parallel. Specify -1
        to use as many cores as possible.
        Default: -1.
    verbose : bool
        Flag whether to produce output (True) or suppress it (False).
        Default: True.

    Returns
    -------
    A utils.TimeSeries brightness movie whose data has shape
    `retina.gridx.shape + (n_frames,)`. Pixels skipped because of `tol`
    are zero.

    Raises
    ------
    TypeError
        If `implant`, `tm` or `retina` is of the wrong type.
    ValueError
        If `implant.etype` is neither 'epiretinal' nor 'subretinal', or if no
        pixel reaches an effective current of `tol`.

    Examples
    --------
    Stimulate a single-electrode array:
    >>> implant = e2cm.ElectrodeArray('subretinal', 0, 0, 0, 0)
    >>> stim = e2cm.Psycho2Pulsetrain(tsample=5e-6, freq=50, amp=20)
    >>> resp = pulse2percept(stim, implant, verbose=False)

    Stimulate a single electrode ('C3') of an Argus I array centered on the
    fovea:
    >>> implant = e2cm.ArgusI()
    >>> stim = {'C3': e2cm.Psycho2Pulsetrain(tsample=5e-6, freq=50, amp=20)}
    >>> resp = pulse2percept(stim, implant) # doctest: +SKIP
    """
    # Check type to avoid backwards compatibility issues
    if not isinstance(implant, e2cm.ElectrodeArray):
        raise TypeError("`implant` must be of type e2cm.ElectrodeArray")

    # Parse `stim` (either single pulse train or a list/dict of pulse trains),
    # and generate a list of pulse trains, one for each electrode
    pt_list = parse_pulse_trains(stim, implant)

    # Generate a standard temporal model if necessary
    if tm is None:
        tm = TemporalModel(pt_list[0].tsample)
    elif not isinstance(tm, TemporalModel):
        raise TypeError("`tm` must be of type ec2b.TemporalModel")

    # Generate a retina if necessary
    if retina is None:
        # Make sure implant fits on retina
        round_to = 500  # round to nearest (microns)
        cspread = 500  # expected current spread (microns)
        xs = [a.x_center for a in implant]
        ys = [a.y_center for a in implant]
        xlo = np.floor((np.min(xs) - cspread) / round_to) * round_to
        xhi = np.ceil((np.max(xs) + cspread) / round_to) * round_to
        ylo = np.floor((np.min(ys) - cspread) / round_to) * round_to
        yhi = np.ceil((np.max(ys) + cspread) / round_to) * round_to
        retina = e2cm.Retina(xlo=xlo, xhi=xhi, ylo=ylo, yhi=yhi,
                             save_data=False, verbose=verbose)
    elif not isinstance(retina, e2cm.Retina):
        raise TypeError("`retina` object must be of type e2cm.Retina")

    # Derive the effective current spread
    if use_ecs:
        ecs, _ = retina.electrode_ecs(implant)
    else:
        _, ecs = retina.electrode_ecs(implant)

    # `ecs_list` is a pixel by `n` list where `n` is the number of layers
    # being simulated. Each value in `ecs_list` is the current contributed
    # by each electrode for that spatial location. Pixels where every value
    # is below `tol` are skipped.
    ecs_list = []
    idx_list = []
    for xx in range(retina.gridx.shape[1]):
        for yy in range(retina.gridx.shape[0]):
            if not np.all(ecs[yy, xx] < tol):
                ecs_list.append(ecs[yy, xx])
                idx_list.append([yy, xx])

    # Apply charge accumulation. Work on copies: `parse_pulse_trains` may hand
    # back the caller's own TimeSeries objects (possibly the same object for
    # several electrodes), which must not be modified in place.
    pt_rows = []
    for p in pt_list:
        data = np.copy(p.data)
        ca = tm.tsample * np.cumsum(np.maximum(0, -data))
        tmp = fftconvolve(ca, tm.gamma_ca, mode='full')
        conv_ca = scale_charge * tm.tsample * tmp[:data.size]

        # Both index sets come from the unmodified input. Non-positive
        # values are pushed towards zero from below and positive values
        # from above, clipped at zero, so no value changes sign.
        idx_neg = np.where(p.data <= 0)[0]
        idx_pos = np.where(p.data > 0)[0]
        data[idx_neg] = np.minimum(data[idx_neg] + conv_ca[idx_neg], 0)
        data[idx_pos] = np.maximum(data[idx_pos] - conv_ca[idx_pos], 0)
        pt_rows.append(data)
    pt_arr = np.array(pt_rows)

    # Which layer to simulate is given by implant type
    if implant.etype == 'epiretinal':
        dolayer = 'NFL'  # nerve fiber layer
    elif implant.etype == 'subretinal':
        dolayer = 'INL'  # inner nuclear layer
    else:
        e_s = "Supported electrode types are 'epiretinal', 'subretinal'"
        raise ValueError(e_s)

    # Without any selected pixel there is no output shape to build the movie
    # from; fail with a clear message instead of an IndexError below.
    if not ecs_list:
        raise ValueError("No pixel reaches an effective current of "
                         "tol=%g; nothing to simulate." % tol)

    sr_list = utils.parfor(calc_pixel, ecs_list, n_jobs=n_jobs, engine=engine,
                           func_args=[pt_arr, tm, rsample, dolayer, dojit])
    bm = np.zeros(retina.gridx.shape + (sr_list[0].data.shape[-1], ))
    # Scatter the per-pixel results back into the (y, x) grid
    idxer = tuple(np.array(idx_list)[:, i] for i in range(2))
    bm[idxer] = [sr.data for sr in sr_list]
    return utils.TimeSeries(sr_list[0].tsample, bm)
예제 #4
0
    def pulse2percept(self, stim, t_percept=None, tol=0.05,
                      layers=('OFL', 'GCL', 'INL')):
        """Transforms an input stimulus to a percept

        Parameters
        ----------
        stim : utils.TimeSeries|list|dict
            There are several ways to specify an input stimulus:

            - For a single-electrode array, pass a single pulse train; i.e.,
              a single utils.TimeSeries object.
            - For a multi-electrode array, pass a list of pulse trains; i.e.,
              one pulse train per electrode.
            - For a multi-electrode array, specify all electrodes that should
              receive non-zero pulse trains by name.
        t_percept : float, optional, default: inherit from `stim` object
            The desired time sampling of the output (seconds). If None, the
            output is not resampled.
        tol : float, optional, default: 0.05
            Ignore pixels whose effective current is smaller than a fraction
            `tol` of the max value.
        layers : str or sequence of str, optional
            Default: ('OFL', 'GCL', 'INL').
            Retina layers to simulate (order and case do not matter):
            - 'OFL': Includes the optic fiber layer in the simulation.
                     If omitted, the tissue activation map will not account
                     for axon streaks.
            - 'GCL': Includes the ganglion cell layer in the simulation.
            - 'INL': Includes the inner nuclear layer in the simulation.
                     If omitted, bipolar cell activity does not contribute
                     to ganglion cell activity.

        Returns
        -------
        A utils.TimeSeries object whose data container comprises the predicted
        brightness over time at each retinal location (x, y), with the last
        dimension of the container representing time (t).

        Raises
        ------
        ValueError
            If a layer is not in `retina.SUPPORTED_LAYERS`, if the pulse
            trains' sampling step differs from the ganglion cell layer's, or
            if no pixel passes the `tol` threshold.

        Examples
        --------
        Simulate a single-electrode array:

        >>> import pulse2percept as p2p
        >>> implant = p2p.implants.ElectrodeArray('subretinal', 0, 0, 0)
        >>> stim = p2p.stimuli.PulseTrain(tsample=5e-6, freq=50, amp=20)
        >>> sim = p2p.Simulation(implant)
        >>> percept = sim.pulse2percept(stim)  # doctest: +SKIP

        Simulate an Argus I array centered on the fovea, where a single
        electrode is being stimulated ('C3'):

        >>> import pulse2percept as p2p
        >>> implant = p2p.implants.ArgusI()
        >>> stim = {'C3': p2p.stimuli.PulseTrain(tsample=5e-6, freq=50,
        ...                                      amp=20)}
        >>> sim = p2p.Simulation(implant)
        >>> resp = sim.pulse2percept(stim)  # doctest: +SKIP
        """
        logger = logging.getLogger(__name__)
        logger.info("Starting pulse2percept...")

        # Get a flattened, all-uppercase array of layer names (a single
        # string is accepted as well)
        layers = np.array([layers]).flatten()
        layers = np.array([layer.upper() for layer in layers])

        # Make sure all specified layers exist
        not_supported = np.array([layer not in retina.SUPPORTED_LAYERS
                                  for layer in layers], dtype=bool)
        if any(not_supported):
            msg = ', '.join(layers[not_supported])
            msg = "Specified layer %s not supported. " % msg
            msg += "Choose from %s." % ', '.join(retina.SUPPORTED_LAYERS)
            raise ValueError(msg)

        # Set up all layers that haven't been set up yet
        self._set_layers()

        # Parse `stim` (either single pulse train or a list/dict of pulse
        # trains), and generate a list of pulse trains, one for each electrode
        pt_list = stimuli.parse_pulse_trains(stim, self.implant)
        pt_data = [pt.data for pt in pt_list]

        if not np.allclose([p.tsample for p in pt_list], self.gcl.tsample):
            e_s = "For now, all pulse trains must have the same sampling "
            e_s += "time step as the ganglion cell layer. In the future, "
            e_s += "this requirement might be relaxed."
            raise ValueError(e_s)

        # Tissue activation maps: If OFL is simulated, includes axon streaks.
        if 'OFL' in layers:
            ecs, _ = self.ofl.electrode_ecs(self.implant)
        else:
            _, ecs = self.ofl.electrode_ecs(self.implant)

        # Index 0 along the third axis of `ecs` is used for the INL, index 1
        # for the GCL/OFL. Test both names explicitly:
        # ``('GCL' or 'OFL') in layers`` would only ever check 'GCL'.
        use_inl = 'INL' in layers
        use_gcl = 'GCL' in layers or 'OFL' in layers

        # Calculate the max of every current spread map
        lmax = np.zeros((2, ecs.shape[-1]))
        if use_inl:
            lmax[0, :] = ecs[:, :, 0, :].max(axis=(0, 1))
        if use_gcl:
            lmax[1, :] = ecs[:, :, 1, :].max(axis=(0, 1))

        # `ecs_list` is a pixel by `n` list where `n` is the number of layers
        # being simulated. Each value in `ecs_list` is the current contributed
        # by each electrode for that spatial location
        ecs_list = []
        idx_list = []
        for xx in range(self.ofl.gridx.shape[1]):
            for yy in range(self.ofl.gridx.shape[0]):
                # If any of the used current spread maps at [yy, xx] are above
                # tolerance, we need to process that pixel
                process_pixel = False
                if use_inl:
                    # For this pixel: Check if the ecs in any layer is large
                    # enough compared to the max across pixels within the layer
                    process_pixel |= np.any(ecs[yy, xx, 0, :] >=
                                            tol * lmax[0, :])
                if use_gcl:
                    process_pixel |= np.any(ecs[yy, xx, 1, :] >=
                                            tol * lmax[1, :])

                if process_pixel:
                    ecs_list.append(ecs[yy, xx])
                    idx_list.append([yy, xx])

        s_info = "tol=%.1f%%, %d/%d px selected" % (tol * 100, len(ecs_list),
                                                    np.prod(ecs.shape[:2]))
        logger.info(s_info)

        # Without any selected pixel there is no output shape to build the
        # percept from; fail clearly instead of with an IndexError below.
        if not ecs_list:
            raise ValueError("No pixel passed the tolerance threshold "
                             "(tol=%g); nothing to simulate." % tol)

        sr_list = utils.parfor(self.gcl.model_cascade,
                               ecs_list, n_jobs=self.num_jobs,
                               engine=self.engine, scheduler=self.scheduler,
                               func_args=[pt_data, layers, self.use_jit])
        bm = np.zeros(self.ofl.gridx.shape +
                      (sr_list[0].data.shape[-1], ))
        # Scatter the per-pixel results back into the (y, x) grid
        idxer = tuple(np.array(idx_list)[:, i] for i in range(2))
        bm[idxer] = [sr.data for sr in sr_list]
        percept = utils.TimeSeries(sr_list[0].tsample, bm)

        # It is possible to specify an additional sampling rate for the
        # percept: If different from the input sampling rate, need to resample.
        # A value of None means "keep the current sampling", so resample() is
        # never called with None.
        if t_percept is not None and t_percept != percept.tsample:
            percept = percept.resample(t_percept)

        logger.info("Done.")

        return percept