Esempio n. 1
0
def test_TemporalModel():
    """Exercise the base TemporalModel API through ``ValidTemporalModel``.

    Covers building, overriding ``dt``, the frozen-attribute guard, the shape
    of predicted percepts, automatic sorting of ``t_percept``, and the error
    paths of ``predict_percept``.
    """
    # Build grid:
    model = ValidTemporalModel()
    npt.assert_equal(model.is_built, False)
    model.build()
    npt.assert_equal(model.is_built, True)

    # Can overwrite default values:
    model = ValidTemporalModel(dt=2e-5)
    npt.assert_almost_equal(model.dt, 2e-5)
    model.build(dt=1.234)
    npt.assert_almost_equal(model.dt, 1.234)

    # Cannot add more attributes:
    with pytest.raises(AttributeError):
        ValidTemporalModel(newparam=1)
    with pytest.raises(FreezeError):
        model.newparam = 1

    # Returns Percept object of proper size:
    npt.assert_equal(model.predict_percept(ArgusI().stim), None)
    model.dt = 1
    for stim in [np.ones((16, 3)), np.zeros((16, 3)), {'A1': [1, 2]},
                 np.ones((16, 2))]:
        implant = ArgusI(stim=stim)
        percept = model.predict_percept(implant.stim)
        # By default, percept is output every 20ms. If stimulus is too short,
        # output at t=[0, 20]. This is mentioned in the docs - for really short
        # stimuli, users should specify the desired time points manually.
        n_time = 1 if implant.stim.time is None else 2
        npt.assert_equal(percept.shape, (implant.stim.shape[0], 1, n_time))
        npt.assert_almost_equal(percept.data, 0)

    # t_percept is automatically sorted:
    model.dt = 0.1
    percept = model.predict_percept(Stimulus(np.zeros((3, 17))),
                                    t_percept=[0.1, 0.8, 0.6])
    npt.assert_almost_equal(percept.time, [0.1, 0.6, 0.8])

    # Invalid calls:
    with pytest.raises(ValueError):
        # Cannot request t_percepts that are not multiples of dt:
        model.predict_percept(Stimulus(np.ones((3, 9))), t_percept=[0.1, 0.11])
    # The next two checks build the model first. Calling an unbuilt model
    # raises NotBuiltError (checked below), which could otherwise satisfy
    # ``pytest.raises(ValueError)`` without ever reaching the time-axis check
    # these cases are meant to cover.
    with pytest.raises(ValueError):
        # Has temporal model but stim.time is None:
        ValidTemporalModel().build().predict_percept(Stimulus(3))
    with pytest.raises(ValueError):
        # stim.time==None but requesting t_percept != None
        ValidTemporalModel().build().predict_percept(Stimulus(3),
                                                     t_percept=[0, 1, 2])
    with pytest.raises(NotBuiltError):
        # Must call build first:
        ValidTemporalModel().predict_percept(Stimulus(3))
    with pytest.raises(TypeError):
        # Must pass a stimulus:
        ValidTemporalModel().build().predict_percept(ArgusI())
def test_PulseTrain():
    """Check PulseTrain construction, trimming, and invalid-argument errors.

    NOTE(review): a second ``test_PulseTrain`` is defined later in this file.
    At module level the later ``def`` rebinds the name, so pytest would only
    collect that later version and this one would never run. Consider
    renaming one of them.
    """
    # All zeros:
    npt.assert_almost_equal(PulseTrain(10, Stimulus(np.zeros((1, 5)))).data,
                            0)
    # Simple fake pulse:
    pulse = Stimulus([[0, -1, 0]], time=[0, 0.1, 0.2])
    for n_pulses in [2, 3, 10]:
        pt = PulseTrain(10, pulse, n_pulses=n_pulses, electrode='A4')
        # The fake pulse has a single -1 sample, so the train should contain
        # exactly ``n_pulses`` of them:
        npt.assert_equal(np.sum(np.isclose(pt.data, -1)), n_pulses)
        npt.assert_equal(pt.electrodes, 'A4')

    # PulseTrains can cut off/trim individual pulses if necessary:
    pt = PulseTrain(3, pulse, stim_dur=11)
    npt.assert_almost_equal(pt.time[-1], 11)
    npt.assert_almost_equal(pt[0, 11], 0)

    # Invalid calls:
    with pytest.raises(TypeError):
        # Wrong stimulus type:
        PulseTrain(10, {'a': 1})
    with pytest.raises(ValueError):
        # Pulse does not fit:
        PulseTrain(100000, pulse)
    with pytest.raises(ValueError):
        # n_pulses does not fit:
        PulseTrain(10, pulse, n_pulses=100000)
    with pytest.raises(ValueError):
        # No time component:
        PulseTrain(10, Stimulus(1))
    with pytest.raises(ValueError):
        # Empty stim:
        # NOTE(review): the Stimulus is built inside the ``raises`` block, so
        # a ValueError from the Stimulus constructor itself would also
        # satisfy this check.
        pulse = Stimulus([[0, 0, 0]], time=[0, 0.1, 0.2], compress=True)
        PulseTrain(10, pulse)
Esempio n. 3
0
def test_Model_predict_percept():
    """Check ``Model.predict_percept`` for each spatial/temporal combination.

    Without a stimulus on the implant every model configuration returns None.
    The error paths cover unbuilt models, invalid ``t_percept`` values and
    wrong input types.
    """
    # A None Model has nothing to build, nothing to perceive:
    model = Model()
    npt.assert_equal(model.predict_percept(ArgusI()), None)
    npt.assert_equal(model.predict_percept(ArgusI(stim={'A1': 1})), None)
    npt.assert_equal(
        model.predict_percept(ArgusI(stim={'A1': 1}), t_percept=[0, 1]), None)

    # Just the spatial model:
    model = Model(spatial=ValidSpatialModel()).build()
    npt.assert_equal(model.predict_percept(ArgusI()), None)
    # Just the temporal model:
    model = Model(temporal=ValidTemporalModel()).build()
    npt.assert_equal(model.predict_percept(ArgusI()), None)
    # Both spatial and temporal:
    model = Model(spatial=ValidSpatialModel(),
                  temporal=ValidTemporalModel()).build()
    npt.assert_equal(model.predict_percept(ArgusI()), None)

    # Invalid calls:
    model = Model(spatial=ValidSpatialModel(), temporal=ValidTemporalModel())
    with pytest.raises(NotBuiltError):
        # Must call build first:
        model.predict_percept(ArgusI())
    model.build()
    with pytest.raises(ValueError):
        # Cannot request t_percepts that are not multiples of dt:
        model.predict_percept(ArgusI(stim={'A1': np.ones(16)}),
                              t_percept=[0.1, 0.11])
    with pytest.raises(ValueError):
        # Has temporal model but stim.time is None. Build the model first so
        # that an unbuilt-model error cannot satisfy this check instead:
        ValidTemporalModel().build().predict_percept(Stimulus(3))
    with pytest.raises(ValueError):
        # stim.time==None but requesting t_percept != None
        model.predict_percept(ArgusI(stim=np.ones(16)), t_percept=[0, 1, 2])
    with pytest.raises(TypeError):
        # Must pass an implant:
        model.predict_percept(Stimulus(3))
Esempio n. 4
0
def test_Stimulus_remove():
    """``Stimulus.remove`` drops the named electrode and keeps the others."""
    stim = Stimulus([[0, 1, 2], [3, 4, 5]], electrodes=['A1', 'C3'])
    # Both electrodes are present before removal:
    for name in ['A1', 'C3']:
        npt.assert_equal(name in stim.electrodes, True)

    stim.remove('A1')

    # Only 'A1' has disappeared:
    for name, still_there in [('A1', False), ('C3', True)]:
        npt.assert_equal(name in stim.electrodes, still_there)
Esempio n. 5
0
def test_PulseTrain():
    """Check PulseTrain construction, pulse counting, and invalid inputs."""
    # A train built from an all-zero stimulus is all zeros:
    zero_train = PulseTrain(10, Stimulus(np.zeros((1, 5))))
    npt.assert_almost_equal(zero_train.data, 0)

    # The fake pulse has one -1 sample, so a train of n pulses has n of them:
    pulse = Stimulus([[0, -1, 0]], time=[0, 0.1, 0.2])
    for n_pulses in (2, 3, 10):
        train = PulseTrain(10, pulse, n_pulses=n_pulses)
        n_negative = np.sum(np.isclose(train.data, -1))
        npt.assert_equal(n_negative, n_pulses)

    # stim_dur too short:
    short_train = PulseTrain(2, pulse, stim_dur=10)
    npt.assert_almost_equal(short_train.data, 0)

    # Invalid calls, each paired with the exception it must raise:
    invalid_calls = [
        # Wrong stimulus type:
        (TypeError, lambda: PulseTrain(10, {'a': 1})),
        # Pulse does not fit:
        (ValueError, lambda: PulseTrain(100000, pulse)),
        # n_pulses does not fit:
        (ValueError, lambda: PulseTrain(10, pulse, n_pulses=100000)),
        # No time component:
        (ValueError, lambda: PulseTrain(10, Stimulus(1))),
        # Empty stim:
        (ValueError, lambda: PulseTrain(
            10, Stimulus([[0, 0, 0]], time=[0, 0.1, 0.2], compress=True))),
    ]
    for expected_error, call in invalid_calls:
        with pytest.raises(expected_error):
            call()
Esempio n. 6
0
def test_MonophasicPulse(amp, delay_dur):
    """Check MonophasicPulse shape, timing, flags and Stimulus wrapping.

    ``amp`` and ``delay_dur`` are test parameters, presumably supplied by a
    ``pytest.mark.parametrize`` decorator that is not visible here (confirm).
    """
    phase_dur = 3.456
    # Basic usage:
    pulse = MonophasicPulse(amp, phase_dur, delay_dur=delay_dur)
    # Zero at t=0, ``amp`` at the midpoint of the phase:
    npt.assert_almost_equal(pulse[0, 0], 0)
    npt.assert_almost_equal(pulse[0, delay_dur + phase_dur / 2.0], amp)
    # Without stim_dur, the time axis ends right after the phase:
    npt.assert_almost_equal(pulse.time[0], 0)
    npt.assert_almost_equal(pulse.time[-1], phase_dur + delay_dur)
    npt.assert_equal(pulse.cathodic, amp <= 0)
    npt.assert_equal(pulse.charge_balanced, False)

    # Custom stim dur:
    pulse = MonophasicPulse(amp, phase_dur, delay_dur=delay_dur, stim_dur=100)
    npt.assert_almost_equal(pulse[0, 0], 0)
    npt.assert_almost_equal(pulse[0, delay_dur + phase_dur / 2.0], amp)
    npt.assert_almost_equal(pulse.time[0], 0)
    # The time axis now extends to stim_dur:
    npt.assert_almost_equal(pulse.time[-1], 100)

    # Zero amplitude:
    pulse = MonophasicPulse(0, phase_dur, delay_dur=delay_dur, electrode='A1')
    npt.assert_almost_equal(pulse.data, 0)
    npt.assert_almost_equal(pulse.time[0], 0)
    npt.assert_almost_equal(pulse.time[-1], phase_dur + delay_dur)
    # An all-zero pulse counts as charge-balanced:
    npt.assert_equal(pulse.charge_balanced, True)
    npt.assert_equal(pulse.electrodes, 'A1')

    # You can wrap a pulse in a Stimulus to overwrite attributes:
    stim = Stimulus(pulse, electrodes='AA1')
    npt.assert_equal(stim.electrodes, 'AA1')
    # Or concatenate:
    stim = Stimulus([pulse, pulse])
    npt.assert_equal(stim.shape[0], 2)
    npt.assert_almost_equal(stim.data[0, :], stim.data[1, :])
    npt.assert_almost_equal(stim.time, pulse.time)
    # Both copies are named 'A1', yet the test expects ['A1', 1]: the second
    # row does not keep the duplicate name.
    npt.assert_equal(stim.electrodes, ['A1', 1])
    # Concatenate and rename:
    stim = Stimulus([pulse, pulse], electrodes=['C1', 'D2'])
    npt.assert_equal(stim.electrodes, ['C1', 'D2'])

    # Invalid calls:
    with pytest.raises(ValueError):
        # Zero phase duration:
        MonophasicPulse(amp, 0)
    with pytest.raises(ValueError):
        # Negative delay:
        MonophasicPulse(amp, phase_dur, delay_dur=-1)
    with pytest.raises(ValueError):
        # stim_dur shorter than phase_dur (1 < 3.456):
        MonophasicPulse(amp, phase_dur, delay_dur=delay_dur, stim_dur=1)
    with pytest.raises(ValueError):
        # More than one electrode:
        MonophasicPulse(amp,
                        phase_dur,
                        delay_dur=delay_dur,
                        electrode=['A1', 'B2'])
Esempio n. 7
0
def test_Stimulus__stim():
    """Check the validation performed when assigning ``Stimulus._stim``.

    The test expects the ``_stim`` setter to accept only a dict with 'data',
    'electrodes' and 'time' entries whose shapes agree. Invalid structure
    raises AttributeError and shape mismatches raise ValueError.
    """
    stim = Stimulus(3)
    # User could try and modify the data container after the constructor, which
    # would lead to inconsistencies between data, electrodes, time. The new
    # property setting mechanism prevents that.
    # Requires dict:
    with pytest.raises(AttributeError):
        stim._stim = np.array([0, 1])
    # Dict must have all required fields:
    fields = ['data', 'electrodes', 'time']
    for field in fields:
        # Drop one required field at a time:
        _fields = deepcopy(fields)
        _fields.remove(field)
        with pytest.raises(AttributeError):
            stim._stim = {f: None for f in _fields}
    # Data must be a 2-D NumPy array:
    # ``data`` is reused (and mutated) by every check below.
    data = {f: None for f in fields}
    with pytest.raises(ValueError):
        data['data'] = np.ones(3)
        stim._stim = data
    # Data rows must match electrodes:
    with pytest.raises(ValueError):
        data['data'] = np.ones((3, 4))
        data['time'] = np.arange(4)
        data['electrodes'] = np.arange(2)
        stim._stim = data
    # Data columns must match time:
    with pytest.raises(ValueError):
        data['data'] = np.ones((3, 4))
        data['electrodes'] = np.arange(3)
        data['time'] = np.arange(7)
        stim._stim = data
    # Time points must be unique:
    # NOTE(review): ``assert_warns_msg`` and ``_unique_timepoints`` are
    # defined outside this view; presumably ``_unique_timepoints`` assigns
    # data with repeated time points to ``stim`` (confirm).
    assert_warns_msg(UserWarning, _unique_timepoints, None, stim, data)
    # But if you do all the things right, you can reset the stimulus by hand:
    # No time axis with a single column:
    data['data'] = np.ones((3, 1))
    data['electrodes'] = np.arange(3)
    data['time'] = None
    stim._stim = data

    # A single time point:
    data['data'] = np.ones((3, 1))
    data['electrodes'] = np.arange(3)
    data['time'] = np.arange(1)
    stim._stim = data

    # Time points only DT apart (DT is defined outside this view) are still
    # accepted:
    data['data'] = np.ones((3, 4))
    data['electrodes'] = np.arange(3)
    data['time'] = np.array([0, 1, 1 + DT, 2])
    stim._stim = data
Esempio n. 8
0
def test_ProsthesisSystem():
    """Check ProsthesisSystem construction, electrode access and stim setter.

    NOTE(review): ``test_ProsthesisSystem`` is defined again later in this
    file. The later definition rebinds the name, so pytest would not collect
    this version. Consider renaming or removing one of them.
    """
    # Invalid instantiations:
    with pytest.raises(ValueError):
        # 'both' is not an accepted value for ``eye``:
        ProsthesisSystem(ElectrodeArray(PointSource(0, 0, 0)), eye='both')

    # Iterating over the electrode array:
    # A bare PointSource is accepted in place of an ElectrodeArray:
    implant = ProsthesisSystem(PointSource(0, 0, 0))
    npt.assert_equal(implant.n_electrodes, 1)
    npt.assert_equal(implant[0], implant.earray[0])
    npt.assert_equal(implant.keys(), implant.earray.keys())

    # Set a stimulus after the constructor:
    npt.assert_equal(implant.stim, None)
    # A scalar is wrapped into a 1x1 Stimulus on electrode 0 with no time axis:
    implant.stim = 3
    npt.assert_equal(isinstance(implant.stim, Stimulus), True)
    npt.assert_equal(implant.stim.shape, (1, 1))
    npt.assert_equal(implant.stim.time, None)
    npt.assert_equal(implant.stim.electrodes, [0])

    with pytest.raises(ValueError):
        # Wrong number of stimuli
        implant.stim = [1, 2]
    with pytest.raises(TypeError):
        # Invalid stim type:
        implant.stim = "stim"
    # Invalid electrode names:
    with pytest.raises(ValueError):
        implant.stim = {'A1': 1}
    with pytest.raises(ValueError):
        implant.stim = Stimulus({'A1': 1})

    # Slots:
    npt.assert_equal(hasattr(implant, '__slots__'), True)
    npt.assert_equal(hasattr(implant, '__dict__'), False)
def test_predict_batched(engine):
    """Batched and one-by-one prediction must give the same percepts.

    Skipped when Jax is unavailable. Only the 'jax' engine supports
    ``predict_percept_batched``. Every other engine is expected to raise
    ImportError.
    """
    if not has_jax:
        pytest.skip("Jax not installed")

    # A mix of plain dicts of pulse trains and Stimulus objects is accepted:
    stims = [
        {'A5': BiphasicPulseTrain(25, 4, 0.45),
         'C7': BiphasicPulseTrain(50, 2.5, 0.75)},
        {'B4': BiphasicPulseTrain(3, 1, 0.32)},
        Stimulus({'F3': BiphasicPulseTrain(12, 3, 1.2)}),
    ]
    implant = ArgusII()
    model = BiphasicAxonMapModel(engine=engine, xystep=2)
    model.build()

    if engine == 'jax':
        batched = model.predict_percept_batched(implant, stims)
        # Reference result: predict each stimulus separately on the implant.
        serial = []
        for current in stims:
            implant.stim = current
            serial.append(model.predict_percept(implant))
        npt.assert_equal(len(serial), len(batched))
        for from_batch, from_serial in zip(batched, serial):
            npt.assert_almost_equal(from_batch.data, from_serial.data)
    else:
        with pytest.raises(ImportError):
            model.predict_percept_batched(implant, stims)
Esempio n. 10
0
def test_Stimulus_merge():
    """Stacking stimuli merges their time axes, also when nested."""
    first = Stimulus([[0, 1, 2, 3, 4]], time=[0, 1, 2, 3, 4])
    second = Stimulus([[0, 1, 2]], time=[-0.5, 1.5, 4.5])
    merged = Stimulus([first, second])
    union_time = np.unique(np.hstack((first.time, second.time)))
    npt.assert_almost_equal(merged.time, union_time, decimal=6)
    # Each row keeps its source's first and last samples:
    for row, source in enumerate([first, second]):
        npt.assert_almost_equal(merged[row, [0, -1]], source[0, [0, -1]])

    # Stacking keeps working when one of the inputs is already a merge:
    third = Stimulus([[14]], time=[9.7])
    nested = Stimulus([merged, third])
    union_time = np.unique(
        np.hstack((first.time, second.time, third.time)))
    npt.assert_almost_equal(nested.time, union_time, decimal=6)
    for row, source in enumerate([first, second, third]):
        npt.assert_almost_equal(nested[row, [0, -1]], source[0, [0, -1]])
Esempio n. 11
0
def test_merge_time_axes_merge_tolerance():
    """Merging two pulse trains must not corrupt samples via interpolation.

    Regression test for
    https://github.com/pulse2percept/pulse2percept/issues/392, where too few
    unique time points were collected when merging time axes.
    """
    stim = Stimulus({"A2": BiphasicPulseTrain(20, 1, 0.45),
                     "A10": BiphasicPulseTrain(30, 1, 0.45)})
    values = np.unique(stim.data)

    # A corrupted (interpolated) sample would land near +1/3 or -1/3:
    for corrupted in (1 / 3, -1 / 3):
        near_corrupted = np.isclose(corrupted, values, atol=0.1)
        npt.assert_equal(near_corrupted.any(), False)
Esempio n. 12
0
def test_Stimulus_arithmetic(scalar):
    """Scalar arithmetic acts on a Stimulus' data; ``>>``/``<<`` shift time.

    Operations between two stimuli, or with arrays/lists, must raise
    TypeError.
    """
    stim = Stimulus([[0, 21, -13, 0, 0]], time=[0, 1, 2, 3, 4])

    # (result of the Stimulus operation, expected data array):
    data_cases = [
        (stim + scalar, stim.data + scalar),
        (scalar + stim, scalar + stim.data),
        (stim - scalar, stim.data - scalar),
        (scalar - stim, scalar - stim.data),
        (stim * scalar, stim.data * scalar),
        (scalar * stim, scalar * stim.data),
        (stim / scalar, stim.data / scalar),
        (-stim, -1 * stim.data),
    ]
    for result, expected in data_cases:
        npt.assert_almost_equal(result.data, expected, decimal=5)

    # Shift operators move the time axis instead of the data:
    time_cases = [
        (stim >> scalar, stim.time + scalar),
        (stim << scalar, stim.time - scalar),
    ]
    for result, expected in time_cases:
        npt.assert_almost_equal(result.time, expected, decimal=5)

    # Unsupported operand combinations. ``scalar / stim`` is not supported
    # because the stimulus contains zeros, so it would always divide by zero.
    unsupported = [
        lambda: scalar / stim,
        lambda: stim + stim,
        lambda: stim - stim,
        lambda: stim * stim,
        lambda: stim / stim,
        lambda: stim + [1, 1],
        lambda: stim * np.array([2, 3]),
        lambda: stim >> np.array([2, 3]),
        lambda: stim << np.array([2, 3]),
    ]
    for operation in unsupported:
        with pytest.raises(TypeError):
            operation()
Esempio n. 13
0
def test_ProsthesisSystem():
    """Check ProsthesisSystem construction, iteration, stim setter and plot.

    Also checks that ``safe_mode=True`` rejects a stimulus that is not
    charge-balanced, and that the class uses ``__slots__``.
    """
    # Invalid instantiations:
    with pytest.raises(ValueError):
        ProsthesisSystem(ElectrodeArray(PointSource(0, 0, 0)),
                         eye='both')
    with pytest.raises(TypeError):
        ProsthesisSystem(Stimulus)

    # Iterating over the electrode array:
    implant = ProsthesisSystem(PointSource(0, 0, 0))
    npt.assert_equal(implant.n_electrodes, 1)
    npt.assert_equal(implant[0], implant.earray[0])
    npt.assert_equal(implant.electrode_names, implant.earray.electrode_names)
    for i, e in zip(implant, implant.earray):
        npt.assert_equal(i, e)

    # Set a stimulus after the constructor:
    npt.assert_equal(implant.stim, None)
    implant.stim = 3
    npt.assert_equal(isinstance(implant.stim, Stimulus), True)
    npt.assert_equal(implant.stim.shape, (1, 1))
    npt.assert_equal(implant.stim.time, None)
    npt.assert_equal(implant.stim.electrodes, [0])

    ax = implant.plot()
    npt.assert_equal(len(ax.texts), 0)
    npt.assert_equal(len(ax.collections), 1)

    with pytest.raises(ValueError):
        # Wrong number of stimuli
        implant.stim = [1, 2]
    with pytest.raises(TypeError):
        # Invalid stim type:
        implant.stim = "stim"
    # Invalid electrode names:
    with pytest.raises(ValueError):
        implant.stim = {'A1': 1}
    with pytest.raises(ValueError):
        implant.stim = Stimulus({'A1': 1})
    # Safe mode requires charge-balanced pulses. Construct the implant outside
    # the ``raises`` block so only the stim assignment can satisfy it:
    implant = ProsthesisSystem(PointSource(0, 0, 0), safe_mode=True)
    with pytest.raises(ValueError):
        implant.stim = 1

    # Slots:
    npt.assert_equal(hasattr(implant, '__slots__'), True)
    npt.assert_equal(hasattr(implant, '__dict__'), False)
Esempio n. 14
0
def test_SpatialModel():
    """Exercise the base SpatialModel API through ``ValidSpatialModel``.

    Covers grid building, overriding ``xystep``, the frozen-attribute guard,
    the shape of predicted percepts, and the error paths of
    ``predict_percept``.
    """
    # Build grid:
    model = ValidSpatialModel()
    npt.assert_equal(model.grid, None)
    npt.assert_equal(model.is_built, False)
    model.build()
    npt.assert_equal(model.is_built, True)
    npt.assert_equal(isinstance(model.grid, GridXY), True)
    npt.assert_equal(isinstance(model.grid.xret, np.ndarray), True)

    # Can overwrite default values:
    model = ValidSpatialModel(xystep=1.234)
    npt.assert_almost_equal(model.xystep, 1.234)
    model.build(xystep=2.345)
    npt.assert_almost_equal(model.xystep, 2.345)

    # Cannot add more attributes:
    with pytest.raises(AttributeError):
        ValidSpatialModel(newparam=1)
    with pytest.raises(FreezeError):
        model.newparam = 1

    # Returns Percept object of proper size:
    npt.assert_equal(model.predict_percept(ArgusI()), None)
    for stim in [np.ones(16), np.zeros(16), {'A1': 2}, np.ones((16, 2))]:
        implant = ArgusI(stim=stim)
        percept = model.predict_percept(implant)
        npt.assert_equal(isinstance(percept, Percept), True)
        # One frame per stimulus time point, or a single frame if the
        # stimulus has no time axis:
        n_time = 1 if implant.stim.time is None else len(implant.stim.time)
        npt.assert_equal(
            percept.shape,
            (model.grid.x.shape[0], model.grid.x.shape[1], n_time))
        npt.assert_almost_equal(percept.data, 0)

    # Invalid calls. Setup statements stay outside the ``raises`` blocks so
    # that only the call under test can satisfy the expected exception.
    # stim.time==None but requesting t_percept != None:
    implant.stim = np.ones(16)
    with pytest.raises(ValueError):
        model.predict_percept(implant, t_percept=[0, 1, 2])
    # Must call build first:
    model = ValidSpatialModel()
    with pytest.raises(NotBuiltError):
        model.predict_percept(ArgusI())
    with pytest.raises(TypeError):
        # Must pass an implant:
        ValidSpatialModel().build().predict_percept(Stimulus(3))
Esempio n. 15
0
def test_ProsthesisSystem_stim():
    """Check stim size validation and that deactivated electrodes get no stim.

    NOTE(review): ``test_ProsthesisSystem_stim`` is defined again later in
    this file. The later definition rebinds the name, so pytest would not
    collect this version. Consider renaming or removing one of them.
    """
    implant = ProsthesisSystem(ElectrodeGrid((13, 13), 20))
    # One more row than the 13x13 grid has electrodes:
    stim = Stimulus(np.ones((13 * 13 + 1, 5)))
    with pytest.raises(ValueError):
        implant.stim = stim

    # Deactivated electrodes cannot receive stimuli:
    implant.deactivate('H4')
    npt.assert_equal(implant['H4'].activated, False)
    implant.stim = {'H4': 1}
    npt.assert_equal('H4' in implant.stim.electrodes, False)

    implant.deactivate('all')
    # With every electrode deactivated the stimulus data must be empty. Check
    # the element count rather than ``not data``: truth-testing an empty
    # NumPy array is deprecated (and an error in recent NumPy releases).
    npt.assert_equal(np.size(implant.stim.data), 0)
    implant.activate('all')
    implant.stim = {'H4': 1}
    npt.assert_equal('H4' in implant.stim.electrodes, True)
Esempio n. 16
0
def test_FadingTemporal():
    """Check FadingTemporal parameters, output shape and decay behaviour.

    A single cathodic pulse produces a decaying response, an anodic pulse
    produces none, and a negative ``tau`` is rejected at build time.
    """
    model = FadingTemporal()
    # User can set their own params:
    model.dt = 0.1
    npt.assert_equal(model.dt, 0.1)
    model.build(dt=1e-4)
    npt.assert_equal(model.dt, 1e-4)
    # User cannot add more model parameters:
    with pytest.raises(FreezeError):
        model.rho = 100

    # Nothing in, None out:
    npt.assert_equal(model.predict_percept(None), None)

    # Zero in = zero out:
    stim = BiphasicPulse(0, 1)
    percept = model.predict_percept(stim, t_percept=[0, 1, 2])
    npt.assert_equal(isinstance(percept, Percept), True)
    npt.assert_equal(percept.shape, (1, 1, 3))
    npt.assert_almost_equal(percept.data, 0)

    # Can't request the same time more than once (this would break the Cython
    # loop, because `idx_frame` is incremented after a write; also doesn't
    # make much sense). Build the stimulus outside the ``raises`` block so
    # only ``predict_percept`` can satisfy it:
    stim = Stimulus(np.ones((1, 100)))
    with pytest.raises(ValueError):
        model.predict_percept(stim, t_percept=[0.2, 0.2])

    # Simple decay for single cathodic pulse:
    model = FadingTemporal(tau=1).build()
    stim = MonophasicPulse(-1, 1, stim_dur=10)
    percept = model.predict_percept(stim, np.arange(stim.duration))
    npt.assert_almost_equal(percept.data.ravel()[:3], [0, 0.633, 0.232],
                            decimal=3)
    npt.assert_almost_equal(percept.data.ravel()[-1], 0, decimal=3)

    # But all zeros for anodic pulse:
    stim = MonophasicPulse(1, 1, stim_dur=10)
    percept = model.predict_percept(stim, np.arange(stim.duration))
    npt.assert_almost_equal(percept.data, 0)

    # tau cannot be negative:
    with pytest.raises(ValueError):
        FadingTemporal(tau=-1).build()
def test_metadata():
    """Pulse trains record their parameters; Stimulus nests them per electrode."""
    train = BiphasicPulseTrain(10, 10, 1, metadata='userdata')
    expected_fields = {
        'user': 'userdata',
        'freq': 10,
        'amp': 10,
        'phase_dur': 1,
        'delay_dur': 0,
    }
    for key, value in expected_fields.items():
        npt.assert_equal(train.metadata[key], value)

    trains = {
        'A2': BiphasicPulseTrain(10, 10, 1, metadata='userdataA2'),
        'B1': BiphasicPulseTrain(11, 9, 2, metadata='userdataB1'),
        'C3': BiphasicPulseTrain(12, 8, 3, metadata='userdataC3'),
    }
    stim = Stimulus(trains, metadata='stimulus_userdata')
    npt.assert_equal(stim.metadata['user'], 'stimulus_userdata')
    # Each electrode entry keeps its source type and that source's metadata:
    per_electrode = stim.metadata['electrodes']
    npt.assert_equal(per_electrode['A2']['type'], BiphasicPulseTrain)
    npt.assert_equal(per_electrode['B1']['metadata']['freq'], 11)
    npt.assert_equal(per_electrode['C3']['metadata']['user'], 'userdataC3')
Esempio n. 18
0
def test_ProsthesisSystem_stim():
    """Check stim validation, stim-based colour mapping in ``plot``, and that
    deactivated electrodes cannot receive stimuli."""
    implant = ProsthesisSystem(ElectrodeGrid((13, 13), 20))
    # One more row than the 13x13 grid has electrodes:
    stim = Stimulus(np.ones((13 * 13 + 1, 5)))
    with pytest.raises(ValueError):
        implant.stim = stim

    # color mapping
    stim = np.zeros((13 * 13, 5))
    stim[84, 0] = 1
    stim[98, 2] = 2
    implant.stim = stim
    plt.cla()
    ax = implant.plot(stim_cmap='hsv')
    plt.colorbar()
    npt.assert_equal(len(ax.collections), 1)
    npt.assert_equal(ax.collections[0].colorbar.vmax, 2)
    # RGBA values are floats; compare approximately rather than exactly:
    npt.assert_almost_equal(ax.collections[0].cmap(ax.collections[0].norm(1)),
                            (0.0, 1.0, 0.9647031631761764, 1))
    # make sure default behaviour unchanged
    plt.cla()
    ax = implant.plot()
    plt.colorbar()
    npt.assert_equal(len(ax.collections), 1)
    npt.assert_equal(ax.collections[0].colorbar.vmax, 1)
    npt.assert_almost_equal(ax.collections[0].cmap(ax.collections[0].norm(1)),
                            (0.993248, 0.906157, 0.143936, 1))

    # Deactivated electrodes cannot receive stimuli:
    implant.deactivate('H4')
    npt.assert_equal(implant['H4'].activated, False)
    implant.stim = {'H4': 1}
    npt.assert_equal('H4' in implant.stim.electrodes, False)

    implant.deactivate('all')
    # With every electrode deactivated the stimulus data must be empty. Check
    # the element count rather than ``not data``: truth-testing an empty
    # NumPy array is deprecated (and an error in recent NumPy releases).
    npt.assert_equal(np.size(implant.stim.data), 0)
    implant.activate('all')
    implant.stim = {'H4': 1}
    npt.assert_equal('H4' in implant.stim.electrodes, True)
def test_biphasicAxonMapSpatial(engine):
    """Check BiphasicAxonMapSpatial / BiphasicAxonMapModel behaviour.

    ``engine`` is a test parameter, presumably parametrized or provided by a
    fixture outside this view (confirm). For ``engine == 'jax'`` the test
    only checks that prediction raises NotImplementedError, then returns.
    """
    # Lambda cannot be too small:
    with pytest.raises(ValueError):
        BiphasicAxonMapSpatial(axlambda=9).build()

    model = BiphasicAxonMapModel(engine=engine, xystep=2).build()
    # Jax not implemented yet
    if engine == 'jax':
        with pytest.raises(NotImplementedError):
            implant = ArgusII()
            implant.stim = Stimulus({'A5': BiphasicPulseTrain(20, 1, 0.45)})
            percept = model.predict_percept(implant)
        return

    # Only accepts biphasic pulse trains with no delay dur
    # (a plain array stimulus is rejected with TypeError):
    implant = ArgusI(stim=np.ones(16))
    with pytest.raises(TypeError):
        model.predict_percept(implant)

    # Nothing in, None out:
    npt.assert_equal(model.predict_percept(ArgusI()), None)

    # Zero in = zero out:
    implant = ArgusI(stim=np.zeros(16))
    percept = model.predict_percept(implant)
    npt.assert_equal(isinstance(percept, Percept), True)
    # A single frame on the model grid, with no time axis:
    npt.assert_equal(percept.shape, list(model.grid.x.shape) + [1])
    npt.assert_almost_equal(percept.data, 0)
    npt.assert_equal(percept.time, None)

    # Should be equal to axon map model if effects models return 1
    model = BiphasicAxonMapSpatial(engine=engine, xystep=2)

    # Stub effect models that ignore their inputs and always return 1:
    def bright_model(freq, amp, pdur):
        return 1

    def size_model(freq, amp, pdur):
        return 1

    def streak_model(freq, amp, pdur):
        return 1

    model.bright_model = bright_model
    model.size_model = size_model
    model.streak_model = streak_model
    model.build()
    axon_map = AxonMapSpatial(xystep=2).build()
    implant = ArgusII()
    implant.stim = Stimulus({'A5': BiphasicPulseTrain(20, 1, 0.45)})
    percept = model.predict_percept(implant)
    percept_axon = axon_map.predict_percept(implant)
    npt.assert_almost_equal(percept.data[:, :, 0],
                            percept_axon.get_brightest_frame())

    # Effect models must be callable
    model = BiphasicAxonMapSpatial(engine=engine, xystep=2)
    model.bright_model = 1.0
    with pytest.raises(TypeError):
        model.build()

    # If t_percept is not specified, there should only be one frame
    model = BiphasicAxonMapSpatial(engine=engine, xystep=2)
    model.build()
    implant = ArgusII()
    implant.stim = Stimulus({'A5': BiphasicPulseTrain(20, 1, 0.45)})
    percept = model.predict_percept(implant)
    npt.assert_equal(percept.time is None, True)
    # If t_percept is specified, only first frame should have data
    # and the rest should be empty
    percept = model.predict_percept(implant, t_percept=[0, 1, 2, 5, 10])
    npt.assert_equal(len(percept.time), 5)
    npt.assert_equal(np.any(percept.data[:, :, 0]), True)
    npt.assert_equal(np.any(percept.data[:, :, 1:]), False)

    # Test that default models give expected values
    model = BiphasicAxonMapSpatial(engine=engine,
                                   rho=400,
                                   axlambda=600,
                                   xystep=1,
                                   xrange=(-20, 20),
                                   yrange=(-15, 15))
    model.build()
    implant = ArgusII()
    implant.stim = Stimulus({'A4': BiphasicPulseTrain(20, 1, 1)})
    percept = model.predict_percept(implant)
    # Reference counts of pixels above several brightness thresholds;
    # presumably recorded from a known-good run with the parameters above.
    npt.assert_equal(np.sum(percept.data > 0.0813), 81)
    npt.assert_equal(np.sum(percept.data > 0.1626), 59)
    npt.assert_equal(np.sum(percept.data > 0.2439), 44)
    npt.assert_equal(np.sum(percept.data > 0.4065), 26)
    npt.assert_equal(np.sum(percept.data > 0.5691), 14)
Esempio n. 20
0
# NOTE(review): this is a fragment stitched together from gallery example
# scripts. ``pt`` (next line), ``phase_dur``, ``amp_th``, ``cath_standard`` and
# ``MonophasicPulse`` are used below but not defined or imported in this
# fragment. They presumably come from earlier parts of the original examples;
# confirm before running.
pt.plot()

###############################################################################
# Generic pulse trains
# --------------------
#
# Finally, you can concatenate any :py:class:`~pulse2percept.stimuli.Stimulus`
# object into a pulse train.
#
# For example, let's define a single ramp stimulus:

import numpy as np
from pulse2percept.stimuli import Stimulus, PulseTrain

# Single ramp:
dt = 1e-3
# Levels 0 -> 1 -> 2 -> 0. Each jump is encoded by two time points only ``dt``
# apart, so the transitions are near-instantaneous:
ramp = Stimulus([[0, 0, 1, 1, 2, 2, 0, 0]],
                time=[0, 1, 1 + dt, 2, 2 + dt, 3, 3 + dt, 5 - dt])
ramp.plot()

# Ramp train:
PulseTrain(20, ramp, stim_dur=200).plot()

# Biphasic ramp:
# The negated ramp is appended after the original, its time axis shifted by 5:
biphasic_ramp = Stimulus(np.concatenate((ramp.data, -ramp.data), axis=1),
                         time=np.concatenate((ramp.time, ramp.time + 5)))
biphasic_ramp.plot()

# Biphasic ramp train:
PulseTrain(20, biphasic_ramp, stim_dur=200).plot()
# NOTE(review): a different example (building a latent-addition stimulus from
# standard and test pulses) starts here.
delay_dur = 12

# Vary this current to determine threshold:
amp_test = 45

# Cathodic phase of the test pulse (delivered after a delay):
cath_test = MonophasicPulse(-amp_test, phase_dur, delay_dur=delay_dur)

###############################################################################
# The anodic phases were always presented 20 ms after the second cathodic phase:

anod_standard = MonophasicPulse(0.5 * amp_th, phase_dur, delay_dur=20)

anod_test = MonophasicPulse(amp_test, phase_dur, delay_dur=delay_dur)

###############################################################################
# The last step is to concatenate all the pulses into a single stimulus:

from pulse2percept.stimuli import Stimulus

data = []
time = []
time_tracker = 0
# Lay the pulses end to end. Each pulse's time axis is offset by the summed
# end times of the pulses before it.
# NOTE(review): if every pulse's time axis starts at 0, the first shifted time
# point of each pulse equals the last time point of the previous one, so the
# concatenated time axis repeats a value at every junction. Stimulus may warn
# about or reject duplicate time points; confirm.
for pulse in (cath_standard, cath_test, anod_standard, anod_test):
    data.append(pulse.data)
    time.append(pulse.time + time_tracker)
    time_tracker += pulse.time[-1]

latent_add = Stimulus(np.concatenate(data, axis=1), time=np.concatenate(time))
latent_add.plot()
Esempio n. 22
0
def test_AsymmetricBiphasicPulse(amp1, amp2, interphase_dur, delay_dur,
                                 cathodic_first):
    """Test construction, sampling, and validation of AsymmetricBiphasicPulse.

    The arguments are presumably supplied by ``pytest.mark.parametrize``
    (the decorator is not visible here) — confirm.

    Parameters
    ----------
    amp1, amp2 : float
        Amplitudes of the first and second phase. The expected values below
        use their absolute values, i.e. the sign of each phase is expected to
        follow ``cathodic_first`` alone.
    interphase_dur : float
        Gap between the two phases.
    delay_dur : float
        Delay before the first phase starts.
    cathodic_first : bool
        Whether the first phase is expected to be negative (cathodic).
    """
    # Two different phase durations, so the pulse is asymmetric in time:
    phase_dur1 = 2.1
    phase_dur2 = 4.87
    # Sample points in the middle of the first phase, the interphase gap, and
    # the second phase (away from the phase edges):
    mid_first_pulse = delay_dur + phase_dur1 / 2.0
    mid_interphase = delay_dur + phase_dur1 + interphase_dur / 2.0
    mid_second_pulse = delay_dur + phase_dur1 + interphase_dur + phase_dur2 / 2
    # Expected amplitudes: the first phase is negative iff cathodic_first, and
    # the second phase has the opposite sign.
    first_amp = -np.abs(amp1) if cathodic_first else np.abs(amp1)
    second_amp = np.abs(amp2) if cathodic_first else -np.abs(amp2)
    # Expected end time of the pulse when no stim_dur is given:
    min_dur = delay_dur + phase_dur1 + interphase_dur + phase_dur2

    # Basic usage:
    pulse = AsymmetricBiphasicPulse(amp1,
                                    amp2,
                                    phase_dur1,
                                    phase_dur2,
                                    interphase_dur=interphase_dur,
                                    delay_dur=delay_dur,
                                    cathodic_first=cathodic_first)
    # pulse[0, t] samples electrode 0 at time t:
    npt.assert_almost_equal(pulse[0, 0], 0)
    npt.assert_almost_equal(pulse[0, mid_first_pulse], first_amp)
    npt.assert_almost_equal(pulse[0, mid_interphase], 0)
    npt.assert_almost_equal(pulse[0, mid_second_pulse], second_amp)
    npt.assert_almost_equal(pulse.time[0], 0)
    npt.assert_almost_equal(pulse.time[-1], min_dur, decimal=3)
    npt.assert_equal(pulse.cathodic_first, cathodic_first)
    # The charge-balance flag must agree with a numerical check: the
    # trapezoidal integral of electrode 0's current over time is ~0.
    # NOTE(review): np.trapz is deprecated in NumPy 2.0 (np.trapezoid).
    npt.assert_equal(pulse.is_charge_balanced,
                     np.isclose(np.trapz(pulse.data, pulse.time)[0], 0))

    # Custom stim dur: the pulse is padded out to t=100 and keeps its shape.
    pulse = AsymmetricBiphasicPulse(amp1,
                                    amp2,
                                    phase_dur1,
                                    phase_dur2,
                                    interphase_dur=interphase_dur,
                                    delay_dur=delay_dur,
                                    cathodic_first=cathodic_first,
                                    stim_dur=100,
                                    electrode='A1')
    npt.assert_almost_equal(pulse[0, 0], 0)
    npt.assert_almost_equal(pulse[0, mid_first_pulse], first_amp)
    npt.assert_almost_equal(pulse[0, mid_interphase], 0)
    npt.assert_almost_equal(pulse[0, mid_second_pulse], second_amp)
    npt.assert_almost_equal(pulse.time[0], 0)
    npt.assert_almost_equal(pulse.time[-1], 100)
    npt.assert_equal(pulse.electrodes, 'A1')

    # Exact stim dur: a stim_dur equal to the pulse length must be accepted
    # and reproduced to 6 decimals.
    stim_dur = delay_dur + phase_dur1 + interphase_dur + phase_dur2
    pulse = AsymmetricBiphasicPulse(amp1,
                                    amp2,
                                    phase_dur1,
                                    phase_dur2,
                                    interphase_dur=interphase_dur,
                                    delay_dur=delay_dur,
                                    cathodic_first=cathodic_first,
                                    stim_dur=stim_dur,
                                    electrode='A1')
    npt.assert_almost_equal(pulse.time[0], 0)
    npt.assert_almost_equal(pulse.time[-1], stim_dur, decimal=6)

    # Zero amplitude:
    pulse = AsymmetricBiphasicPulse(0,
                                    0,
                                    phase_dur1,
                                    phase_dur2,
                                    interphase_dur=interphase_dur,
                                    delay_dur=delay_dur,
                                    cathodic_first=cathodic_first)
    npt.assert_almost_equal(pulse.data, 0)
    npt.assert_almost_equal(pulse.time[0], 0)
    npt.assert_almost_equal(pulse.time[-1], min_dur, decimal=3)
    npt.assert_equal(pulse.is_charge_balanced,
                     np.isclose(np.trapz(pulse.data, pulse.time)[0], 0))

    # If both phases have the same values, it's basically a symmetric biphasic
    # pulse:
    abp = AsymmetricBiphasicPulse(amp1,
                                  amp1,
                                  phase_dur1,
                                  phase_dur1,
                                  interphase_dur=interphase_dur,
                                  delay_dur=delay_dur,
                                  cathodic_first=cathodic_first)
    bp = BiphasicPulse(amp1,
                       phase_dur1,
                       interphase_dur=interphase_dur,
                       delay_dur=delay_dur,
                       cathodic_first=cathodic_first)
    bp_min_dur = phase_dur1 * 2 + interphase_dur + delay_dur
    # Compare both pulses at 5 evenly spaced time points:
    npt.assert_almost_equal(abp[:, np.linspace(0, bp_min_dur, num=5)],
                            bp[:, np.linspace(0, bp_min_dur, num=5)])
    npt.assert_equal(abp.cathodic_first, bp.cathodic_first)

    # If one phase is zero, it's basically a monophasic pulse:
    abp = AsymmetricBiphasicPulse(amp1,
                                  0,
                                  phase_dur1,
                                  phase_dur2,
                                  interphase_dur=interphase_dur,
                                  delay_dur=delay_dur,
                                  cathodic_first=cathodic_first)
    # Give the monophasic pulse the same total duration for comparison:
    mono = MonophasicPulse(first_amp,
                           phase_dur1,
                           delay_dur=delay_dur,
                           stim_dur=min_dur)
    npt.assert_almost_equal(abp[:, np.linspace(0, min_dur, num=5)],
                            mono[:, np.linspace(0, min_dur, num=5)])
    npt.assert_equal(abp.cathodic_first, mono.cathodic)

    # You can wrap a pulse in a Stimulus to overwrite attributes
    # (`pulse` here is still the zero-amplitude pulse from above):
    stim = Stimulus(pulse, electrodes='AA1')
    npt.assert_equal(stim.electrodes, 'AA1')
    # Or concatenate:
    stim = Stimulus([pulse, pulse])
    npt.assert_equal(stim.shape[0], 2)
    npt.assert_almost_equal(stim.data[0, :], stim.data[1, :])
    npt.assert_almost_equal(stim.time, pulse.time, decimal=2)
    # Concatenated pulses get integer electrode names by default:
    npt.assert_equal(stim.electrodes, [0, 1])
    # Concatenate and rename:
    stim = Stimulus([pulse, pulse], electrodes=['C1', 'D2'])
    npt.assert_equal(stim.electrodes, ['C1', 'D2'])

    # Invalid calls:
    # Zero-length first phase:
    with pytest.raises(ValueError):
        AsymmetricBiphasicPulse(amp1, amp2, 0, phase_dur2)
    # Zero-length second phase:
    with pytest.raises(ValueError):
        AsymmetricBiphasicPulse(amp1, amp2, phase_dur1, 0)
    # Negative interphase gap:
    with pytest.raises(ValueError):
        AsymmetricBiphasicPulse(amp1,
                                amp2,
                                phase_dur1,
                                phase_dur2,
                                interphase_dur=-1)
    # Negative delay:
    with pytest.raises(ValueError):
        AsymmetricBiphasicPulse(amp1,
                                amp2,
                                phase_dur1,
                                phase_dur2,
                                interphase_dur=interphase_dur,
                                delay_dur=-1)
    # stim_dur shorter than the pulse itself (the two phases alone last
    # 2.1 + 4.87):
    with pytest.raises(ValueError):
        AsymmetricBiphasicPulse(amp1,
                                amp2,
                                phase_dur1,
                                phase_dur2,
                                interphase_dur=interphase_dur,
                                delay_dur=delay_dur,
                                stim_dur=1)
    # A pulse is defined on a single electrode:
    with pytest.raises(ValueError):
        AsymmetricBiphasicPulse(amp1,
                                amp2,
                                phase_dur1,
                                phase_dur2,
                                interphase_dur=interphase_dur,
                                delay_dur=delay_dur,
                                electrode=['A1', 'B2'])
Esempio n. 23
0
def test_BiphasicPulse(amp, interphase_dur, delay_dur, cathodic_first):
    """Test construction, sampling, and validation of BiphasicPulse.

    The arguments are presumably supplied by ``pytest.mark.parametrize``
    (the decorator is not visible here) — confirm.

    Parameters
    ----------
    amp : float
        Phase amplitude. The expected values use ``np.abs(amp)``, i.e. the
        sign of each phase is expected to follow ``cathodic_first`` alone.
    interphase_dur : float
        Gap between the two phases.
    delay_dur : float
        Delay before the first phase starts.
    cathodic_first : bool
        Whether the first phase is expected to be negative (cathodic).
    """
    phase_dur = 3.19
    # Sample points in the middle of the first phase, the interphase gap, and
    # the second phase (away from the phase edges):
    mid_first_pulse = delay_dur + phase_dur / 2.0
    mid_interphase = delay_dur + phase_dur + interphase_dur / 2.0
    mid_second_pulse = delay_dur + interphase_dur + 1.5 * phase_dur
    # Expected amplitudes: both phases have equal magnitude, opposite sign.
    first_amp = -np.abs(amp) if cathodic_first else np.abs(amp)
    second_amp = -first_amp
    # Expected end time of the pulse when no stim_dur is given:
    min_dur = 2 * phase_dur + delay_dur + interphase_dur

    # Basic usage:
    pulse = BiphasicPulse(amp,
                          phase_dur,
                          interphase_dur=interphase_dur,
                          delay_dur=delay_dur,
                          cathodic_first=cathodic_first)
    # pulse[0, t] samples electrode 0 at time t:
    npt.assert_almost_equal(pulse[0, 0], 0)
    npt.assert_almost_equal(pulse[0, mid_first_pulse], first_amp)
    npt.assert_almost_equal(pulse[0, mid_interphase], 0)
    npt.assert_almost_equal(pulse[0, mid_second_pulse], second_amp)
    npt.assert_almost_equal(pulse.time[0], 0)
    npt.assert_almost_equal(pulse.time[-1], min_dur, decimal=3)
    npt.assert_equal(pulse.cathodic_first, cathodic_first)
    # Symmetric phases always cancel out:
    npt.assert_equal(pulse.is_charge_balanced, True)

    # Custom stim dur: the pulse is padded out to t=100 and keeps its shape.
    pulse = BiphasicPulse(amp,
                          phase_dur,
                          interphase_dur=interphase_dur,
                          delay_dur=delay_dur,
                          cathodic_first=cathodic_first,
                          stim_dur=100,
                          electrode='B1')
    npt.assert_almost_equal(pulse[0, 0], 0)
    npt.assert_almost_equal(pulse[0, mid_first_pulse], first_amp)
    npt.assert_almost_equal(pulse[0, mid_interphase], 0)
    npt.assert_almost_equal(pulse[0, mid_second_pulse], second_amp)
    npt.assert_almost_equal(pulse.time[0], 0)
    npt.assert_almost_equal(pulse.time[-1], 100)
    npt.assert_equal(pulse.electrodes, 'B1')

    # Exact stim dur: a stim_dur equal to the pulse length must be accepted
    # and reproduced to 6 decimals.
    stim_dur = 2 * phase_dur + interphase_dur + delay_dur
    pulse = BiphasicPulse(amp,
                          phase_dur,
                          interphase_dur=interphase_dur,
                          delay_dur=delay_dur,
                          cathodic_first=cathodic_first,
                          stim_dur=stim_dur)
    npt.assert_almost_equal(pulse.time[0], 0)
    npt.assert_almost_equal(pulse.time[-1], stim_dur, decimal=6)

    # Zero amplitude:
    pulse = BiphasicPulse(0,
                          phase_dur,
                          interphase_dur=interphase_dur,
                          delay_dur=delay_dur,
                          cathodic_first=cathodic_first)
    npt.assert_almost_equal(pulse.data, 0)
    npt.assert_almost_equal(pulse.time[0], 0)
    npt.assert_almost_equal(pulse.time[-1], min_dur, decimal=3)
    npt.assert_equal(pulse.is_charge_balanced, True)

    # You can wrap a pulse in a Stimulus to overwrite attributes
    # (`pulse` here is still the zero-amplitude pulse from above):
    stim = Stimulus(pulse, electrodes='AA1')
    npt.assert_equal(stim.electrodes, ['AA1'])
    # Or concatenate:
    stim = Stimulus([pulse, pulse])
    npt.assert_equal(stim.shape[0], 2)
    npt.assert_almost_equal(stim.data[0, :], stim.data[1, :])
    npt.assert_almost_equal(stim.time, pulse.time)
    # Concatenated pulses get integer electrode names by default:
    npt.assert_equal(stim.electrodes, [0, 1])
    # Concatenate and rename:
    stim = Stimulus([pulse, pulse], electrodes=['C1', 'D2'])
    npt.assert_equal(stim.electrodes, ['C1', 'D2'])

    # Floating point math with np.unique is tricky, but this works:
    # (only checks that construction does not raise)
    BiphasicPulse(10,
                  np.pi,
                  interphase_dur=np.pi,
                  delay_dur=np.pi,
                  stim_dur=5 * np.pi)

    # Invalid calls:
    # Zero-length phase:
    with pytest.raises(ValueError):
        BiphasicPulse(amp, 0)
    # Negative interphase gap:
    with pytest.raises(ValueError):
        BiphasicPulse(amp, phase_dur, interphase_dur=-1)
    # Negative delay:
    with pytest.raises(ValueError):
        BiphasicPulse(amp,
                      phase_dur,
                      interphase_dur=interphase_dur,
                      delay_dur=-1)
    # stim_dur shorter than the pulse itself (two phases of 3.19 each):
    with pytest.raises(ValueError):
        BiphasicPulse(amp,
                      phase_dur,
                      interphase_dur=interphase_dur,
                      delay_dur=delay_dur,
                      stim_dur=1)
    # A pulse is defined on a single electrode:
    with pytest.raises(ValueError):
        BiphasicPulse(amp,
                      phase_dur,
                      interphase_dur=interphase_dur,
                      delay_dur=delay_dur,
                      electrode=['A1', 'B2'])
Esempio n. 24
0
def test_Stimulus_plot():
    """Test ``Stimulus.plot``: returned axes, plotted data, and invalid input.

    Every figure this test creates is closed again, so figures do not pile
    up across the test session. Objects needed by the invalid-call checks
    are built outside ``pytest.raises``, so that only the ``plot`` call
    itself is expected to raise.
    """
    # Stimulus with one electrode
    stim = Stimulus([[0, -10, 10, -10, 10, -10, 0]],
                    time=[0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0])
    for time in [None, Ellipsis, slice(None)]:
        # Different ways to plot all data points:
        fig, ax = plt.subplots()
        stim.plot(time=time, ax=ax)
        npt.assert_equal(isinstance(ax, Subplot), True)
        npt.assert_almost_equal(
            ax.get_yticks(),
            [stim.data.min(), 0, stim.data.max()])
        npt.assert_equal(len(ax.lines), 1)
        npt.assert_almost_equal(ax.lines[0].get_data()[1].min(),
                                stim.data.min())
        npt.assert_almost_equal(ax.lines[0].get_data()[1].max(),
                                stim.data.max())
        plt.close(fig)

    # Plot a range of time values (times are sliced, but end points are
    # interpolated):
    fig, ax = plt.subplots()
    ax = stim.plot(time=(0.2, 0.6), ax=ax)
    npt.assert_equal(isinstance(ax, Subplot), True)
    npt.assert_equal(len(ax.lines), 1)
    t_vals = ax.lines[0].get_data()[0]
    npt.assert_almost_equal(t_vals[0], 0.2)
    npt.assert_almost_equal(t_vals[-1], 0.6)
    plt.close(fig)

    # Plot exact time points:
    t_vals = [0.2, 0.3, 0.4]
    fig, ax = plt.subplots()
    stim.plot(time=t_vals, ax=ax)
    npt.assert_equal(isinstance(ax, Subplot), True)
    npt.assert_equal(len(ax.lines), 1)
    npt.assert_almost_equal(ax.lines[0].get_data()[0], t_vals)
    npt.assert_almost_equal(ax.lines[0].get_data()[1],
                            np.squeeze(stim[:, t_vals]))
    plt.close(fig)

    # Plot multiple electrodes with string names (one axes per electrode):
    for n_electrodes in [2, 3, 4]:
        stim = Stimulus(np.random.rand(n_electrodes, 20),
                        electrodes=[f'E{i}' for i in range(n_electrodes)])
        fig, axes = plt.subplots(ncols=n_electrodes)
        stim.plot(ax=axes)
        npt.assert_equal(isinstance(axes, (list, np.ndarray)), True)
        for ax, electrode in zip(axes, stim.electrodes):
            npt.assert_equal(isinstance(ax, Subplot), True)
            npt.assert_equal(len(ax.lines), 1)
            npt.assert_equal(ax.get_ylabel(), electrode)
            npt.assert_almost_equal(ax.lines[0].get_data()[0], stim.time)
            npt.assert_almost_equal(ax.lines[0].get_data()[1],
                                    stim[electrode, :])
        plt.close(fig)

    # Invalid calls (`stim` is the last multi-electrode stimulus from the
    # loop above):
    with pytest.raises(TypeError):
        stim.plot(electrodes=1.2)
    with pytest.raises(TypeError):
        stim.plot(time=0)
    with pytest.raises(TypeError):
        stim.plot(ax='as')
    with pytest.raises(TypeError):
        stim.plot(time='0 0.1')
    # A 1-D source without a time axis cannot be plotted:
    no_time_stim = Stimulus(np.ones(10))
    with pytest.raises(NotImplementedError):
        no_time_stim.plot()

    stim = Stimulus(np.ones((3, 10)))
    # More axes (4) than electrodes (3):
    fig, axes = plt.subplots(nrows=4)
    with pytest.raises(ValueError):
        stim.plot(ax=axes)
    plt.close(fig)
    # One entry of the axes array is not an Axes object:
    fig, axes = plt.subplots(nrows=3)
    axes[1] = 0
    with pytest.raises(TypeError):
        stim.plot(ax=axes)
    plt.close(fig)

    # NOTE(review): a rejected plot() call might create a figure before it
    # raises (cannot tell from here); close anything left over.
    plt.close('all')
Esempio n. 25
0
def test_Stimulus_append():
    """Test ``Stimulus.append`` for valid and invalid ``other`` stimuli."""
    base = Stimulus([[0, 1, 0]], time=[0, 1, 2])
    tail = Stimulus([[0, 2]], time=[0, 0.5])

    # The last sample of `base` and the first sample of `tail` are merged,
    # so the combined time axis continues from t=2:
    merged = base.append(tail)
    npt.assert_almost_equal(merged.data, [[0, 1, 0, 2]])
    npt.assert_almost_equal(merged.time, [0, 1, 2, 2.5])

    # Shifting `tail` by 10 before appending leaves a gap:
    shifted = base.append(tail >> 10)
    npt.assert_almost_equal(shifted.time, [0, 1, 2, 12, 12.5])

    # (expected exception, zero-argument call that should raise it). Objects
    # are built inside each callable, i.e. inside `pytest.raises`.
    bad_calls = [
        # 'other' must be a Stimulus:
        (TypeError, lambda: base.append(np.array([[0, 1, 2]]))),
        # 'other' cannot have time=None:
        (ValueError, lambda: base.append(Stimulus(3))),
        # 'self' cannot have time=None:
        (ValueError, lambda: Stimulus(3).append(base)),
        # 'other' with a different electrode name ('B1'):
        (ValueError,
         lambda: base.append(Stimulus([[1, 2]], electrodes='B1'))),
        # Negative time axis:
        (NotImplementedError,
         lambda: base.append(Stimulus([[0, 2]], time=[-1, 0]))),
    ]
    for expected_exc, call in bad_calls:
        with pytest.raises(expected_exc):
            call()
Esempio n. 26
0
def test_Stimulus___eq__():
    """Test ``Stimulus.__eq__`` / ``__ne__`` against equal and unequal objects."""
    # Stimuli built from identical source data compare equal:
    for src in [3, [], np.ones(3), [3, 4, 5], np.ones((3, 6))]:
        npt.assert_equal(Stimulus(src) == Stimulus(src), True)

    reference = Stimulus(np.ones((2, 3)), compress=True)

    # Compressed vs. uncompressed, checked with both operators:
    uncompressed = Stimulus(np.ones((2, 3)), compress=False)
    npt.assert_equal(reference == uncompressed, False)
    npt.assert_equal(reference != uncompressed, True)

    # Each of these must compare unequal to `reference`:
    differing = [
        Stimulus(reference, electrodes=[0, 'A2']),  # electrode names
        Stimulus(reference, time=[0, 3], compress=True),  # time points
        Stimulus(np.ones((2, 4))),  # data shape
        Stimulus(np.ones(2)),  # data shape
        Stimulus(np.ones((2, 3)) * 1.1, compress=True),  # data values
        Stimulus(np.ones((2, 5))),  # shape
    ]
    for other in differing:
        npt.assert_equal(reference == other, False)

    # Comparison with an object of a different type:
    npt.assert_equal(reference == ODict(), False)
    npt.assert_equal(reference != ODict(), True)
    # Time vs. no time:
    npt.assert_equal(Stimulus(2) == reference, False)
    # An empty list and an empty tuple give equal stimuli (odd, but allowed):
    npt.assert_equal(Stimulus([]), Stimulus(()))
Esempio n. 27
0
# Quantize the signal: snap every sample of `x` to the nearest of five evenly
# spaced levels in [-1, 1] (broadcasting gives a samples x levels distance
# table; argmin picks the closest level per sample).
# NOTE(review): `x`, `t`, and `plt` are defined outside this snippet; `x` is
# assumed to be a 1-D array sampled at times `t` — confirm.
levels = np.linspace(-1, 1, num=5)
data = levels[np.argmin(np.abs(x[:, np.newaxis] - levels), axis=1)]

plt.plot(t, data, label='discretized')
plt.plot(t, x, label='original')
plt.xlabel('Time (ms)')
plt.ylabel('Amplitude')
plt.legend()

##############################################################################
# We can turn this signal into a :py:class:`~pulse2percept.stimuli.Stimulus`
# object as follows:

from pulse2percept.stimuli import Stimulus

# A single row (one electrode), amplitudes scaled by 10, sampled at `t`:
stim = Stimulus(10 * data.reshape((1, -1)), time=t)
stim.plot()

##############################################################################
# Alternatively, we can automate this process by creating a new class
# ``SinusoidalPulse`` that inherits from ``Stimulus``:

class SinusoidalPulse(Stimulus):
    def __init__(self, amp, freq, phase, stim_dur, n_levels=5, dt=0.001):
        """Sinusoidal pulse

        Parameters
        ----------
        amp : float
            Maximum stimulus amplitude (uA)
Esempio n. 28
0
def test_Stimulus___getitem__():
    """Test ``Stimulus.__getitem__``: electrode indexing and time lookup.

    The first index selects electrodes (by position, name, slice, or boolean
    mask). The second index is a *time value*, not a column index; values
    between stored time points are linearly interpolated, as the assertions
    below show.
    """
    # 3 electrodes x 4 time points. The assertions below imply a default
    # time axis of [0, 1, 2, 3], so electrode 0 holds the value 1 + t.
    stim = Stimulus(1 + np.arange(12).reshape((3, 4)))
    # Slicing:
    npt.assert_equal(stim[:], stim.data)
    npt.assert_equal(stim[...], stim.data)
    npt.assert_equal(stim[:, :], stim.data)
    npt.assert_equal(stim[:2], stim.data[:2])
    # A scalar time index keeps a 2D (n_electrodes, 1) result:
    npt.assert_equal(stim[:, 0.0], stim.data[:, 0].reshape((-1, 1)))
    npt.assert_equal(stim[0, :], stim.data[0, :])
    npt.assert_equal(stim[0, ...], stim.data[0, ...])
    npt.assert_equal(stim[..., 0], stim.data[..., 0].reshape((-1, 1)))
    # More advanced slicing of time is possible, but needs a step size:
    with pytest.raises(ValueError):
        stim[:, 2:5]
    with pytest.raises(ValueError):
        stim[:, :3]
    with pytest.raises(ValueError):
        stim[:, 2:]
    # With a step, the sampled times below match np.arange(start, stop, step)
    # (start defaults to 0, stop to the end of the time axis):
    npt.assert_almost_equal(stim[0, 1.2:1.65:0.15], [[2.2, 2.35, 2.5]])
    npt.assert_almost_equal(stim[0, :0.6:0.2], [[1.0, 1.2, 1.4]])
    npt.assert_almost_equal(stim[0, 2.7::0.2], [[3.7, 3.9]])
    npt.assert_almost_equal(stim[0, ::2.6], [[1.0, 3.6]])
    # Single element:
    npt.assert_equal(stim[0, 0], stim.data[0, 0])
    # Interpolating time (linear between stored time points):
    npt.assert_almost_equal(stim[0, 2.6], 3.6)
    npt.assert_almost_equal(stim[..., 2.3],
                            np.array([[3.3], [7.3], [11.3]]),
                            decimal=3)
    # The second dimension is not a column index! (Here time t happens to sit
    # at column t, so integer times coincide with column values.)
    npt.assert_almost_equal(stim[0, 0], 1.0)
    npt.assert_almost_equal(stim[0, [0, 1]], np.array([[1.0, 2.0]]))
    npt.assert_almost_equal(stim[0, [0.21, 1]], np.array([[1.21, 2.0]]))
    npt.assert_almost_equal(stim[[0, 1], [0.21, 1]],
                            np.array([[1.21, 2.0], [5.21, 6.0]]))

    # "Valid" index errors: electrode out of range, non-integer electrode:
    with pytest.raises(IndexError):
        stim[10, :]
    with pytest.raises(IndexError):
        stim[3.3, 0]

    # Times can be extrapolated (take on value of end points):
    stim = Stimulus(1 + np.arange(12).reshape((3, 4)))
    npt.assert_almost_equal(stim[0, 9.901], 4)
    # If time=None, you cannot interpolate/extrapolate:
    stim = Stimulus([3, 4, 5])
    npt.assert_almost_equal(stim[0], stim.data[0, 0])
    with pytest.raises(ValueError):
        stim[0, 0.2]

    # With a single time point, interpolate is still possible (any requested
    # time returns that single value):
    stim = Stimulus(np.arange(3).reshape((-1, 1)))
    npt.assert_almost_equal(stim[0], stim.data[0, 0])
    npt.assert_almost_equal(stim[0, 0], stim.data[0, 0])
    npt.assert_almost_equal(stim[0, 3.33], stim.data[0, 0])

    # Annoying but possible:
    stim = Stimulus([])
    npt.assert_almost_equal(stim[:], stim.data)
    with pytest.raises(IndexError):
        stim[0]

    # Electrodes by string:
    stim = Stimulus([[0, 1], [2, 3]], electrodes=['A1', 'B2'])
    npt.assert_almost_equal(stim['A1'], [0, 1])
    npt.assert_almost_equal(stim['A1', :], [0, 1])
    npt.assert_almost_equal(stim[['A1', 'B2'], 0], [[0], [2]])
    npt.assert_almost_equal(stim[['A1', 'B2'], :], stim.data)

    # Electrodes by slice (10 electrodes, one time point each):
    stim = Stimulus(np.arange(10))
    npt.assert_almost_equal(stim[1::3], np.array([[1], [4], [7]]))

    # Binary arrays (boolean masks over electrodes and/or time points):
    stim = Stimulus(np.arange(6).reshape((2, 3)),
                    electrodes=['A1', 'B2'],
                    time=[0.1, 0.3, 0.5])
    npt.assert_almost_equal(stim[stim.electrodes != 'A1', :], [[3, 4, 5]])
    npt.assert_almost_equal(stim[stim.electrodes == 'B2', :], [[3, 4, 5]])
    # An all-False electrode mask gives an empty selection:
    npt.assert_almost_equal(stim[stim.electrodes == 'C9', :], np.zeros((0, 3)))
    npt.assert_almost_equal(stim[stim.electrodes == 'C9', 0.1].size, 0)
    # Electrode mask combined with an interpolated time value:
    npt.assert_almost_equal(stim[stim.electrodes == 'B2', 0.1001],
                            3.0005,
                            decimal=3)
    npt.assert_almost_equal(stim[stim.electrodes == 'B2', 0.2], 3.5)
    # A boolean time mask selects stored time points:
    npt.assert_almost_equal(stim[:, stim.time < 0.4], [[0, 1], [3, 4]])
    npt.assert_almost_equal(stim[stim.electrodes == 'B2', stim.time < 0.4],
                            [3, 4])
    npt.assert_almost_equal(stim[:, stim.time > 0.6], np.zeros((2, 0)))
    npt.assert_almost_equal(stim['A1', stim.time > 0.6].size, 0)
    npt.assert_almost_equal(stim['A1', np.isclose(stim.time, 0.3)], [1])
Esempio n. 29
0
###############################################################################
# Now we need to activate one electrode at a time, and predict the resulting
# percept. We could build a :py:class:`~pulse2percept.stimuli.Stimulus` object
# with a for loop that does just that, or we can use the following trick.
#
# The stimulus' data container is a (electrodes, timepoints) shaped 2D NumPy
# array. Activating one electrode at a time is therefore equivalent to passing
# an identity matrix whose size equals the number of electrodes. In code:

# Find the names of all the electrodes in the dataset:
# NOTE(review): `data`, `argus`, and `model` come from earlier in the example;
# `data` is presumably a pandas DataFrame with an `electrode` column — confirm.
electrodes = data.electrode.unique()
# Activate one electrode at a time: row i of the identity matrix is electrode
# i and column j is time point j, so at time point j only electrode j is on
# (amplitude 1).
import numpy as np
from pulse2percept.stimuli import Stimulus
argus.stim = Stimulus(np.eye(len(electrodes)), electrodes=electrodes)

###############################################################################
# Using the model's
# :py:func:`~pulse2percept.models.AxonMapModel.predict_percept`, we then get
# a Percept object where each frame is the percept generated from activating
# a single electrode:

percepts = model.predict_percept(argus)
percepts.play()

###############################################################################
# Finally, we can visualize the ground-truth and simulated phosphenes
# side-by-side:

# Presumably used by the plotting call that follows (outside this view):
from pulse2percept.viz import plot_argus_simulated_phosphenes
Esempio n. 30
0
##############################################################################
# Simplest stimulus
# -----------------
# :py:class:`~pulse2percept.stimuli.Stimulus` is the base class to generate
# different types of stimuli. The simplest way to instantiate a Stimulus is
# to pass a scalar value which is interpreted as the current amplitude
# for a single electrode.

# Let's start by importing the necessary modules
from pulse2percept.stimuli import (MonophasicPulse, BiphasicPulse, Stimulus,
                                   PulseTrain)

import numpy as np

# Scalar input: a single electrode with amplitude 10.
stim = Stimulus(10)

##############################################################################
# Parameters we don't specify will take on default values. We can inspect
# all current stimulus parameters as follows:

print(stim)

##############################################################################
# This command also reveals a number of other parameters to set, such as:
#
# * ``electrodes``: We can either specify the electrodes in the source
#   or within the stimulus. If none are specified it looks up the source
#   electrode.
#
# * ``metadata``: Optionally we can include metadata to the stimulus we