Exemplo n.º 1
0
 def _generate_timeseries(self, inp_data, tgt_data, timestep_start):
     """
     Wrap raw input and target sample arrays in ``TSContinuous`` objects.

     Sample ``k`` of the arrays is placed at time
     ``(timestep_start + k) * self.dt``. Sets ``self.num_timesteps``,
     ``self.times``, ``self.input``, ``self.target`` and ``self.duration``.

     :param inp_data:        input samples; the first axis indexes time steps
     :param tgt_data:        target samples at the same time points as ``inp_data``
     :param timestep_start:  index of the first time step
     """
     n_steps = inp_data.shape[0]
     self.num_timesteps = n_steps

     step_indices = np.arange(n_steps) + timestep_start
     self.times = step_indices * self.dt

     # - The input series explicitly ends one ``dt`` after its last sample;
     #   the target keeps TSContinuous' default ``t_stop`` (NOTE(review): the
     #   two series therefore may end at different times -- confirm intended)
     end_time = (timestep_start + n_steps) * self.dt
     self.input = TSContinuous(self.times, inp_data, t_stop=end_time)
     self.target = TSContinuous(self.times, tgt_data)

     self.duration = n_steps * self.dt
Exemplo n.º 2
0
def test_continuous_call():
    """
    Test evaluating a ``TSContinuous`` at arbitrary time points by calling it.
    """
    from rockpool import TSContinuous

    # - Two-channel series sampled at t = 0.1, 0.2, 0.3, 0.4
    sample_times = 0.1 * np.arange(1, 5)
    sample_values = np.arange(2) * 2 + np.arange(4).reshape(-1, 1)
    series = TSContinuous(sample_times, sample_values)
    series_empty = TSContinuous()
    series_single = TSContinuous(2, [3, 2])

    # - Regular series: on-sample, interpolated and out-of-range queries
    assert np.allclose(series(0.1), np.array([[0, 2]]))
    assert np.allclose(series(0.25), np.array([[1.5, 3.5]]))
    before_start = series(0)
    assert np.isnan(before_start).all() and before_start.shape == (1, 2)
    multi = series([0, 0.1, 0.25])
    assert np.allclose(multi[1:], np.array([[0, 2], [1.5, 3.5]]))
    assert np.isnan(multi[0]).all() and multi.shape == (3, 2)

    # - Empty series: one row per query time, zero channels
    assert series_empty(1).shape == (1, 0)
    assert series_empty([0, 1, 4]).shape == (3, 0)

    # - Single-sample series: defined only at t = 2, NaN everywhere else
    assert (series_single(2) == np.array([[3, 2]])).all()
    off_sample = series_single(0)
    assert np.isnan(off_sample).all() and off_sample.shape == (1, 2)
    single_results = series_single([0, 2, 1, 3, 2, 4])
    nan_rows = (True, False, True, True, False, True)
    expected_nan_mask = np.array([2 * [is_nan] for is_nan in nan_rows])
    assert (np.isnan(single_results) == expected_nan_mask).all()
    for row in (1, 4):
        assert (single_results[row] == np.array([3, 2])).all()
Exemplo n.º 3
0
def test_ForceRateEulerJax():
    """Test evolution, reset and parameter-shape validation of ForceRateEulerJax."""
    from rockpool import TSContinuous
    from rockpool.layers import ForceRateEulerJax

    # - Random weights and biases drawn uniformly from [-1, 1)
    w_in = 2 * np.random.rand(1, 2) - 1
    w_out = 2 * np.random.rand(2, 1) - 1
    bias = 2 * np.random.rand(2) - 1
    tau = 20e-3 * np.ones(2)

    layer = ForceRateEulerJax(
        w_in=w_in, w_out=w_out, bias=bias, noise_std=0.1, tau=tau, dt=0.01
    )

    # - Constant single-channel input and constant two-channel forcing signal
    ts_input = TSContinuous(times=np.arange(15) * 0.01, samples=np.ones((15, 1)))
    ts_force = TSContinuous(times=np.arange(15) * 0.01, samples=np.ones((15, 2)))

    # - Evolving must advance time and change the state...
    state_initial = np.copy(layer.state)
    layer.evolve(ts_input, ts_force, duration=0.1)
    assert layer.t == 0.1
    assert (state_initial != layer.state).any()

    # - ...and resetting must restore both
    layer.reset_all()
    assert layer.t == 0
    assert (state_initial == layer.state).all()

    # - Parameter sets with mismatched shapes must be rejected
    mismatched = [
        dict(w_out=np.zeros((3, 1)), tau=np.zeros(3), bias=np.zeros(3)),
        dict(w_out=np.zeros((2, 1)), tau=np.zeros(3), bias=np.zeros(3)),
        dict(w_out=np.zeros((2, 1)), tau=np.zeros(2), bias=np.zeros(3)),
    ]
    for kwargs in mismatched:
        with pytest.raises(AssertionError):
            ForceRateEulerJax(w_in=np.zeros((1, 2)), **kwargs)
Exemplo n.º 4
0
def test_save_load():
    """
    Test saving and loading function for timeseries.

    The saved files are removed in a ``finally`` clause. A failing assertion
    therefore does not leave ``test_tsc.npz`` / ``test_tse.npz`` behind in the
    working directory. Leftover files could otherwise let a later run load
    stale data even if saving were broken.
    """
    from contextlib import suppress
    from os import remove

    from rockpool import TSEvent, TSContinuous, load_ts_from_file

    saved_files = ("test_tsc.npz", "test_tse.npz")

    # - Generate time series objects
    times = [1, 3, 6]
    samples = np.random.randn(3)
    channels = [0, 1, 1]
    tsc = TSContinuous(
        times, samples, t_start=-1, t_stop=8, periodic=True, name="continuous"
    )
    tse = TSEvent(
        times,
        channels,
        num_channels=3,
        t_start=-1,
        t_stop=8,
        periodic=True,
        name="events",
    )

    try:
        # - Store objects
        tsc.save("test_tsc")
        tse.save("test_tse")

        # - Load objects
        tscl = load_ts_from_file("test_tsc.npz")
        tsel = load_ts_from_file("test_tse.npz")

        # - Verify that attributes are still correct
        assert (tscl.times == times).all(), "TSContinuous: times changed."
        assert (
            tscl.samples == samples.reshape(-1, 1)
        ).all(), "TSContinuous: samples changed."
        assert tscl.name == "continuous", "TSContinuous: name changed."
        assert tscl.t_start == -1, "TSContinuous: t_start changed."
        assert tscl.t_stop == 8, "TSContinuous: t_stop changed."
        assert tscl.periodic, "TSContinuous: periodic changed."
        assert (tsel.times == times).all(), "TSEvent: times changed."
        assert (tsel.channels == channels).all(), "TSEvent: channels changed."
        assert tsel.name == "events", "TSEvent: name changed."
        assert tsel.t_start == -1, "TSEvent: t_start changed."
        assert tsel.t_stop == 8, "TSEvent: t_stop changed."
        assert tsel.periodic, "TSEvent: periodic changed."
        assert tsel.num_channels == 3, "TSEvent: num_channels changed."
    finally:
        # - Remove saved files; a file may be missing if saving itself failed
        for path in saved_files:
            with suppress(FileNotFoundError):
                remove(path)
Exemplo n.º 5
0
def test_updown():
    """Test construction, evolution and reset of an ``FFUpDown`` layer."""
    from rockpool import TSContinuous
    from rockpool.layers import FFUpDown

    # - Random 2 x 4 weight matrix
    weights = np.random.rand(2, 4)

    layer = FFUpDown(
        weights=weights, dt=0.01, thr_down=0.02, thr_up=0.01, tau_decay=0.1
    )

    # - Sizes follow the weight matrix; scalar thresholds become length-2 arrays
    assert layer.size == 4, "Problem with size"
    assert layer.size_in == 2, "Problem with size_in"
    assert (layer.thr_down == np.array([0.02, 0.02])).all(), "Problem with thr_down"
    assert (layer.thr_up == np.array([0.01, 0.01])).all(), "Problem with thr_up"

    # - Two-channel input: sine and cosine of 15 evenly spaced phases in [0, 1]
    phase = np.linspace(0, 1, 15)
    ts_input = TSContinuous(
        times=np.arange(15) * 0.01,
        samples=np.column_stack((np.sin(phase), np.cos(phase))),
    )

    # - Evolving must advance time and change the state; reset restores both
    state_initial = np.copy(layer.state)
    layer.evolve(ts_input, duration=0.1)
    assert layer.t == 0.1
    assert (state_initial != layer.state).any()

    layer.reset_all()
    assert layer.t == 0
    assert (state_initial == layer.state).all()
Exemplo n.º 6
0
def test_updown_in_net():
    """Test an ``FFUpDown`` layer feeding a ``RecDIAF`` layer inside a ``Network``.

    Evolving the network must advance its time and change the state of the
    downstream ``RecDIAF`` layer.
    """
    from rockpool import TSContinuous
    from rockpool.networks import Network
    from rockpool.layers import FFUpDown
    from rockpool.layers import RecDIAF

    # - Generic parameters: random 2 x 4 weights for the FFUpDown layer
    weights = np.random.rand(2, 4)

    # - Layer generation
    fl0 = FFUpDown(weights=weights)
    fl1 = RecDIAF(np.zeros((4, 2)), np.zeros((2, 2)), dt=0.002)

    # - Generate network
    net = Network(fl0, fl1)

    # - Input signal: sine and cosine of 15 evenly spaced phases in [0, 1]
    tsInCont = TSContinuous(
        times=np.arange(15) * 0.01,
        samples=np.vstack(
            (np.sin(np.linspace(0, 1, 15)), np.cos(np.linspace(0, 1, 15)))
        ).T,
    )

    # - Compare network time and RecDIAF state before and after evolution
    vStateBefore = np.copy(fl1.state)
    net.evolve(tsInCont, duration=0.1)
    assert net.t == 0.1
    assert (vStateBefore != fl1.state).any()
Exemplo n.º 7
0
def test_continuous_merge():
    """
    Test merge method of TSContinuous.

    Covers merging two series, merging with an empty series on either side,
    and merging a list of series with duplicate removal.
    """
    from rockpool import TSContinuous

    # - Generate a few TSContinuous objects and an empty series
    samples = np.random.randint(10, size=(2, 6))
    empty_series = TSContinuous(t_start=1)
    series_list = [
        TSContinuous([1, 2], samples[:2, :2], t_start=-1, t_stop=3),
        TSContinuous([1.5], samples[0, 2:4], t_start=0, t_stop=4),
        TSContinuous([2, 2.5], samples[:2, 4:6], t_start=-2, t_stop=3),
    ]

    # Merging two series: samples end up ordered by time
    merged_fromtwo = series_list[0].merge(series_list[1])
    assert merged_fromtwo.t_start == -1, "Wrong t_start for merged series."
    assert merged_fromtwo.t_stop == 4, "Wrong t_stop for merged series."
    assert (
        merged_fromtwo.times == np.array([1, 1.5, 2])
    ).all(), "Wrong time trace for merged series."
    correct_samples = np.vstack((samples[0, :2], samples[0, 2:4], samples[1, :2]))
    assert (
        merged_fromtwo.samples == correct_samples
    ).all(), "Wrong samples for merged series."

    # Merging with empty series: the non-empty operand is returned unchanged
    merged_empty_first = empty_series.merge(series_list[0])
    assert (
        merged_empty_first.t_start == series_list[0].t_start
    ), "Wrong t_start when merging with empty."
    assert (
        merged_empty_first.t_stop == series_list[0].t_stop
    ), "Wrong t_stop when merging with empty."
    assert (
        merged_empty_first.samples == series_list[0].samples
    ).all(), "Wrong samples when merging with empty"
    assert (
        merged_empty_first.times == series_list[0].times
    ).all(), "Wrong time trace when merging with empty"

    merged_empty_last = series_list[0].merge(empty_series)
    assert (
        merged_empty_last.t_start == series_list[0].t_start
    ), "Wrong t_start when merging with empty."
    assert (
        merged_empty_last.t_stop == series_list[0].t_stop
    ), "Wrong t_stop when merging with empty."
    assert (
        merged_empty_last.samples == series_list[0].samples
    ).all(), "Wrong samples when merging with empty"
    assert (
        merged_empty_last.times == series_list[0].times
    ).all(), "Wrong time trace when merging with empty"

    # Merging with list of series; the duplicate time point t=2 is dropped
    merged_with_list = empty_series.merge(series_list, remove_duplicates=True)
    assert (
        merged_with_list.num_channels == 2
    ), "Wrong channel count when merging with list."
    assert merged_with_list.t_start == -2, "Wrong t_start when merging with list."
    assert merged_with_list.t_stop == 4, "Wrong t_stop when merging with list."
    assert (
        merged_with_list.times == np.array([1, 1.5, 2, 2.5])
    ).all(), "Wrong time trace when merging with list."
    assert (
        merged_with_list.samples
        == np.vstack(
            (samples[0, :2], samples[0, 2:4], samples[1, :2], samples[1, 4:6])
        )
    ).all(), "Wrong samples when merging with list."
Exemplo n.º 8
0
def test_continuous_append_c():
    """
    Check ``TSContinuous.append_c`` (appending channels), including empty operands.
    """
    from rockpool import TSContinuous

    samples = np.random.randint(10, size=(2, 4))
    empty_series = TSContinuous()
    first = TSContinuous([1, 2], samples[:2, :2], t_start=-1, t_stop=2)
    second = TSContinuous([1], samples[0, -2:], t_start=0, t_stop=2)

    # - Two non-empty series: `second` has no sample at t=2, so its channels
    #   are expected to be NaN in that row
    combined = first.append_c(second)
    assert combined.t_start == -1, "Wrong t_start for appended series."
    assert combined.t_stop == 2, "Wrong t_stop for appended series."
    assert (
        combined.times == np.array([1, 2])
    ).all(), "Wrong time trace for appended series."
    samples_ok = (
        (combined.samples[:, :2] == samples[:2, :2]).all()
        and (combined.samples[0, -2:] == samples[0, -2:]).all()
        and np.isnan(combined.samples[1, -2:]).all()
    )
    assert samples_ok, "Wrong samples for appended series."

    # - Empty series first: the result matches the empty series
    empty_then_first = empty_series.append_c(first)
    assert (
        empty_then_first.t_start == empty_series.t_start
    ), "Wrong t_start when appending with empty."
    assert (
        empty_then_first.t_stop == empty_series.t_stop
    ), "Wrong t_stop when appending with empty."
    flat_result = empty_then_first.samples.flatten()
    flat_empty = empty_series.samples.flatten()
    assert (flat_result == flat_empty).all(), "Wrong samples when appending with empty"
    assert (
        empty_then_first.times == empty_series.times
    ).all(), "Wrong time trace when appending with empty"

    # - Empty series last: the result matches the non-empty series
    first_then_empty = first.append_c(empty_series)
    assert (
        first_then_empty.t_start == first.t_start
    ), "Wrong t_start when appending with empty."
    assert (
        first_then_empty.t_stop == first.t_stop
    ), "Wrong t_stop when appending with empty."
    assert (
        first_then_empty.samples == first.samples
    ).all(), "Wrong samples when appending with empty"
    assert (
        first_then_empty.times == first.times
    ).all(), "Wrong time trace when appending with empty"
Exemplo n.º 9
0
def test_continuous_operators():
    """
    Exercise construction and the arithmetic operators of ``TSContinuous``.
    """
    from rockpool import TSContinuous

    # - Construction, including a single-sample series
    TSContinuous([0], [0])
    series = TSContinuous([0, 1, 2, 3], [1, 2, 3, 4])
    other = TSContinuous([1, 2, 3, 4], [5, 6, 7, 8])

    # - A mismatch between the number of times and samples must be rejected
    with pytest.raises(ValueError):
        TSContinuous([0, 1, 2], [0])

    # - Every operator is exercised with a scalar and with another series,
    #   in both binary and augmented-assignment form

    # Addition
    series = series + 1
    series += 5
    series = series + other
    series += other

    # Subtraction
    series = series - 3
    series -= 2
    series = series - other
    series -= other

    # Multiplication
    series = series * 0.9
    series *= 0.2
    series = series * other
    series *= other

    # True division
    series = series / 2.0
    series /= 1.0
    series = series / other
    series /= other

    # Floor division
    series = series // 1.0
    series //= 1.0
    series = series // other
    series //= other
Exemplo n.º 10
0
def test_largescale():
    """Evolve a large ``RecLIFJax_IO`` network on a sparse, periodic spike input."""
    from rockpool import TSEvent, TSContinuous
    from rockpool.layers import RecLIFCurrentInJax, RecLIFJax, RecLIFJax_IO

    # Numpy
    import numpy as np

    # - Network dimensions and neuron parameters
    n_neurons, n_inputs, n_outputs = 200, 500, 1
    tau_mem, tau_syn, bias = 50e-3, 100e-3, 0.0

    # - Random parameters (drawn in the order w_in, w_recurrent, w_out)
    params = {
        "w_in": np.random.rand(n_inputs, n_neurons) - 0.5,
        "w_recurrent": 0.1
        * np.random.randn(n_neurons, n_neurons)
        / np.sqrt(n_neurons),
        "w_out": 2 * np.random.rand(n_neurons, n_outputs) - 1,
        "tau_mem": tau_mem,
        "tau_syn": tau_syn,
        "bias": np.full(n_neurons, bias),
    }
    layer = RecLIFJax_IO(**params)

    # - Time base for input and target
    num_repeats = 1
    dur_input = 1000e-3
    dt = 1e-3
    num_steps = int(np.round(dur_input / dt))
    timebase = np.linspace(0, num_steps * dt, num_steps)

    # - Inputs may only spike during the first 50 ms
    trigger = np.atleast_2d(timebase < 50e-3).T

    # - Chirp target (constructed, but not used further by this test)
    chirp = np.atleast_2d(np.sin(timebase * 2 * np.pi * (timebase * 10))).T
    TSContinuous(timebase, chirp, periodic=True, name="Target")

    # - Random spike raster: each (step, channel) spikes with prob. spiking_prob
    #   inside the trigger window
    spiking_prob = 0.01
    raster = (
        np.random.rand(num_steps * num_repeats, n_inputs) < spiking_prob * trigger
    )
    step_idx, channel_idx = np.argwhere(raster).T
    input_sp_ts = TSEvent(
        timebase[step_idx],
        channel_idx,
        name="Input",
        periodic=True,
        t_start=0.0,
        t_stop=dur_input,
    )

    layer.evolve(input_sp_ts)
Exemplo n.º 11
0
def test_continuous_indexing():
    """Test indexing a ``TSContinuous`` by time, by channel, and by both at once."""
    from rockpool import TSContinuous

    times = 0.1 * np.arange(6)
    samples = np.arange(4) + np.arange(6).reshape(-1, 1)
    ts = TSContinuous(times, samples)

    def check(sub, expected_times, expected_samples):
        # - Compare a sub-series against the expected time trace and samples
        assert (sub.times == expected_times).all()
        assert (sub.samples == expected_samples).all()

    # - Time axis: full slice, range, negative step, explicit list, scalar
    check(ts[:], times, samples)
    check(ts[0.2:0.5], times[[2, 3, 4]], samples[[2, 3, 4]])
    check(ts[0.1:0.5:-0.2], times[[3, 1]], samples[[3, 1]])
    check(ts[[0.5, 0, 0.1]], times[[5, 0, 1]], samples[[5, 0, 1]])
    check(ts[0.2], times[[2]], samples[2])

    # - Channel axis, with `:` or `None` selecting all times
    check(ts[:, :], times, samples)
    check(ts[None, 1:3], times, samples[:, 1:3])
    check(ts[:, 1::-2], times, samples[:, [3, 1]])
    check(ts[None, [2, 0, 3]], times, samples[:, [2, 0, 3]])
    check(ts[:, 2], times, samples[:, [2]])

    # - Time and channel axes together
    check(ts[:0.4, [3, 1]], times[:4], samples[:4, [3, 1]])
Exemplo n.º 12
0
def test_evolve():
    """Evolve the network returned by ``test_build`` on a random continuous input."""
    from rockpool import TSContinuous

    network = test_build()

    # - 100 uniformly random samples spread over 0..10
    num_samples = 100
    input_series = TSContinuous(
        np.linspace(0, 10, num_samples), np.random.rand(num_samples)
    )

    network.evolve(input_series)
Exemplo n.º 13
0
def test_adam():
    """Smoke-test Adam training of a ``RecRateEulerJax`` wrapped by ``add_train_output``."""
    from rockpool.layers import RecRateEulerJax
    from rockpool.layers.training import add_train_output
    from rockpool import TSContinuous

    # - Random input, recurrent and output weights in [-1, 1)
    w_in = 2 * np.random.rand(1, 2) - 1
    w_recurrent = 2 * np.random.rand(2, 2) - 1
    w_out = 2 * np.random.rand(2, 1) - 1

    # - Build the layer and attach the training shim in one step
    layer = add_train_output(
        RecRateEulerJax(
            w_in=w_in,
            w_recurrent=w_recurrent,
            w_out=w_out,
            bias=0,
            noise_std=0.1,
            tau=20,
            dt=1,
        )
    )

    # - Four-sample input pulse and a linearly rising target
    ts_input = TSContinuous([0, 1, 2, 3], [0, 1, 0, 0])
    ts_target = TSContinuous([0, 1, 2, 3], [0.1, 0.2, 0.3, 0.4])

    # - First step returns loss and gradient closures, which must be callable
    loss_fcn, grad_fcn = layer.train_adam(ts_input, ts_target, is_first=True)
    loss_fcn()
    grad_fcn()

    # - An intermediate and a final training step
    layer.train_adam(ts_input, ts_target)
    layer.train_adam(ts_input, ts_target, is_last=True)
Exemplo n.º 14
0
def test_continuous_clip():
    """
    Test ``TSContinuous.clip`` in time, in channels, and on an empty series.
    """
    from rockpool import TSContinuous

    # - Generate series
    times = np.arange(1, 6) * 0.1
    samples = np.arange(5).reshape(-1, 1) + np.arange(2) * 2
    ts = TSContinuous(times, samples)
    ts_empty = TSContinuous()

    # - Clip ts in time
    assert (ts.clip(0.2, 0.4, include_stop=True).times == times[1:4]).all()
    assert (ts.clip(0.2, 0.4, include_stop=True).samples == samples[1:4]).all()
    assert (ts.clip(0.2, 0.4, include_stop=False).times == times[1:3]).all()
    assert (ts.clip(0.2, 0.4, include_stop=False).samples == samples[1:3]).all()
    ts_limits = ts.clip(0.2, 0.35, include_stop=True, sample_limits=True)
    assert np.allclose(ts_limits.times, np.array([0.2, 0.3, 0.35]))
    expected_samples = np.vstack((samples[1:3], [2.5, 4.5]))
    assert np.allclose(ts_limits.samples, expected_samples)
    ts_beyond = ts.clip(0.4, 0.6, sample_limits=False)
    assert (ts_beyond.times == times[-2:]).all()
    assert (ts_beyond.samples == samples[-2:]).all()

    # - Clip ts channels
    assert (ts.clip(channels=1).times == times).all()
    assert (ts.clip(channels=1).samples == samples[:, [1]]).all()
    assert (ts.clip(channels=[1, 0]).times == times).all()
    assert (ts.clip(channels=[1, 0]).samples == samples[:, [1, 0]]).all()

    # - Clip ts channels and time
    ts_ch_t = ts.clip(0.2, 0.4, channels=1, include_stop=True)
    assert (ts_ch_t.times == times[1:4]).all()
    assert (ts_ch_t.samples == samples[1:4, [1]]).all()

    # - Clip empty
    ts_et = ts_empty.clip(0.2, 0.4, sample_limits=False)
    assert ts_et.isempty()
    assert ts_et.t_start == 0.2 and ts_et.t_stop == 0.4

    # - Each invalid call needs its own `pytest.raises` context: inside a
    #   single context, statements after the first raise are never executed
    with pytest.raises(IndexError):
        ts_empty.clip(channels=0)
    with pytest.raises(IndexError):
        ts_empty.clip(2, 4, channels=0)
Exemplo n.º 15
0
def test_FFRateEuler():
    """Test evolution, reset and argument validation of ``FFRateEuler``."""
    from rockpool import TSContinuous
    from rockpool.layers import FFRateEuler

    # - Random 2 x 3 weights and biases in [-1, 1)
    weights = 2 * np.random.rand(2, 3) - 1
    bias = 2 * np.random.rand(3) - 1

    layer = FFRateEuler(weights=weights, bias=bias, noise_std=0.1, dt=0.01)

    # - Constant two-channel input
    ts_input = TSContinuous(times=np.arange(15) * 0.01, samples=np.ones((15, 2)))

    # - Evolving must advance time and change the state; reset restores both
    state_initial = np.copy(layer.state)
    layer.evolve(ts_input, duration=0.1)
    assert layer.t == 0.1
    assert (state_initial != layer.state).any()

    layer.reset_all()
    assert layer.t == 0
    assert (state_initial == layer.state).all()

    # - Missing weights and parameters whose length does not match the
    #   (scalar) weights must be rejected
    with pytest.raises(TypeError):
        FFRateEuler(weights=None)

    for bad_kwargs in ({"bias": [1, 1]}, {"tau": [1, 1]}, {"gain": [1, 1]}):
        with pytest.raises(AssertionError):
            FFRateEuler(weights=1, **bad_kwargs)
Exemplo n.º 16
0
    tau_decay=np.array((None, None)),
    name="updown",
    multiplex_spikes=True,
)

# - Load data
# Per-class sampling probabilities passed to `rec.generate_data` (values sum
# to 1.0); presumably keys are annotation class indices -- TODO confirm
classprobs = {0: 0.8, 1: 0.05, 2: 0, 3: 0.05, 4: 0.05, 5: 0, 18: 0.05}
annotations, recordings = rec.load_from_file(rec.save_path)
# Generate input data, target data and the matching annotations from the
# loaded recordings (num_beats=1000, contiguous segments, no bad signal)
data_in, data_tgt, anno_curr = rec.generate_data(
    num_beats=1000,
    annotations=annotations,
    rec_data=recordings,
    use_recordings=None,
    probs=classprobs,
    include_bad_signal=False,
    min_len_segment=2,
    use_cont_segments=True,
)
# Sample indices at which a new rhythm starts; scaled by `dt_in` below to
# obtain times
rhythm_starts = anno_curr.idx_new_start

# - Convert to spikes
# Sample k of `data_in` is placed at time k * dt_in; `aslayer` (defined
# earlier in the script, not visible here) evolves on this continuous series
times = np.arange(data_in.shape[0]) * dt_in
spike_input = aslayer.evolve(TSContinuous(times, data_in))

# - Load to chip
# Rhythm starts and total duration are converted from sample counts to times
# with `dt_in`; the total duration covers one sample beyond the data
reservoir.load_events(
    tsAS=spike_input,
    vtRhythmStart=np.array(rhythm_starts) * dt_in,
    tTotalDuration=(data_in.shape[0] + 1) * dt_in,
)
Exemplo n.º 17
0
def test_training_FFwd():
    """Train the readout of an ``FFLIFJax_IO`` layer with the LIF/Jax SGD shim."""
    from rockpool import TSEvent, TSContinuous
    from rockpool.layers import RecLIFCurrentInJax, RecLIFJax, RecLIFJax_IO, FFLIFJax_IO
    import numpy as np

    # - Network dimensions, neuron parameters and simulation time step
    n_neurons, n_inputs, n_outputs = 100, 100, 1
    tau_mem, tau_syn, bias = 50e-3, 100e-3, 0.0
    dt = 1e-3

    # - Random parameters; input weights are scaled by 1 / n_inputs
    params = {
        "w_in": (np.random.rand(n_inputs, n_neurons) - 0.5) / n_inputs,
        "w_out": 2 * np.random.rand(n_neurons, n_outputs) - 1,
        "tau_mem": tau_mem,
        "tau_syn": tau_syn,
        "bias": np.full(n_neurons, bias),
    }
    layer = FFLIFJax_IO(**params, dt=dt)

    # - Time base for input and target
    num_repeats = 1
    dur_input = 1000e-3
    num_steps = int(np.round(dur_input / dt))
    timebase = np.linspace(0, num_steps * dt, num_steps)

    # - Inputs may spike wherever timebase < dur_input
    trigger = np.atleast_2d(timebase < dur_input).T

    # - Periodic chirp target
    chirp = np.atleast_2d(np.sin(timebase * 2 * np.pi * (timebase * 10))).T
    target_ts = TSContinuous(timebase, chirp, periodic=True, name="Target")

    # - Random spike raster: each (step, channel) spikes with prob. spiking_prob
    spiking_prob = 0.01
    raster = (
        np.random.rand(num_steps * num_repeats, n_inputs) < spiking_prob * trigger
    )
    step_idx, channel_idx = np.argwhere(raster).T
    input_sp_ts = TSEvent(
        timebase[step_idx],
        channel_idx,
        name="Input",
        periodic=True,
        t_start=0,
        t_stop=dur_input,
    )

    # - Simulate initial network state
    layer.randomize_state()
    layer.evolve(input_sp_ts)

    # - Add training shim
    from rockpool.layers.training import add_shim_lif_jax_sgd

    layer = add_shim_lif_jax_sgd(layer)

    # - Train for 100 steps; only the first call sets is_first
    for step in range(100):
        layer.randomize_state()
        loss_fcn, grad_fcn = layer.train_output_target(
            input_sp_ts, target_ts, is_first=(step == 0)
        )
        loss_fcn()
        grad_fcn()
Exemplo n.º 18
0
def test_RecLIFCurrentInJax():
    """Test evolution, reset and parameter-shape validation of ``RecLIFCurrentInJax``."""
    from rockpool import TSContinuous
    from rockpool.layers import RecLIFCurrentInJax

    net_size = 2
    dt = 1e-3

    # - Random recurrent weights and biases in [-1, 1); 20 ms time constants
    w_recurrent = 2 * np.random.rand(net_size, net_size) - 1
    bias = 2 * np.random.rand(net_size) - 1
    tau_m = 20e-3 * np.ones(net_size)
    tau_s = 20e-3 * np.ones(net_size)

    layer = RecLIFCurrentInJax(
        w_recurrent=w_recurrent,
        bias=bias,
        noise_std=0.1,
        tau_mem=tau_m,
        tau_syn=tau_s,
        dt=dt,
    )

    # - Constant input current to every neuron; sample times are 0, 1, ..., 99
    ts_input = TSContinuous(times=np.arange(100), samples=np.ones((100, net_size)))

    # - Evolving for a fixed duration advances time and changes Vmem
    vmem_initial = np.copy(layer.state["Vmem"])
    layer.evolve(ts_input, duration=0.1)
    assert layer.t == 0.1
    assert (vmem_initial != layer.state["Vmem"]).any()

    # - Without an explicit duration, evolution ends at the last input sample
    layer.reset_all()
    layer.evolve(ts_input)
    assert layer.t == 99

    # - Reset restores time and membrane potential
    layer.reset_all()
    assert layer.t == 0
    assert (vmem_initial == layer.state["Vmem"]).all()

    # - Parameter sets with inconsistent shapes must be rejected
    bad_configs = [
        dict(w_recurrent=np.zeros((3, 2)), tau_mem=np.zeros(3), tau_syn=np.zeros(3)),
        dict(w_recurrent=np.zeros((2, 2)), tau_mem=np.zeros(3), tau_syn=np.zeros(3)),
        dict(w_recurrent=np.zeros((2, 2)), tau_mem=np.zeros(2), tau_syn=np.zeros(3)),
    ]
    for kwargs in bad_configs:
        with pytest.raises(AssertionError):
            RecLIFCurrentInJax(bias=np.zeros(3), **kwargs)
Exemplo n.º 19
0
def test_setWeightsIn():
    """
    Test setting input weights after construction for the NEST layers.

    For each layer class, two layers are built that differ only in their
    initial input weights. After the second layer's input weights are set to
    those of the first, evolving both on the same input must give identical
    states. ``terminate()`` is called in a ``finally`` clause so the layers
    are shut down even when a check fails.
    """
    from rockpool.layers import FFIAFNest, RecIAFSpkInNest, RecAEIFSpkInNest
    from rockpool import TSEvent, TSContinuous
    import numpy as np

    # - Generic parameters
    weights_in = np.array([[-0.5, 0.02, 0.4], [0.2, -0.3, -0.15]])
    weights_rec = np.random.rand(3, 3) * 0.01
    bias = 0.01
    tau_mem = [0.02, 0.05, 0.1]
    tau_syn_exc = [0.2, 0.01, 0.01]
    tau_syn_inh = tau_syn_exc

    # - Different input weights for initialization of fl1
    #   (differs from `weights_in` only at element [0, 0])
    weights_in1 = np.array([[-0.1, 0.02, 0.4], [0.2, -0.3, -0.15]])

    # - Input time series
    tsInpCont = TSContinuous(np.arange(15) * 0.01, np.ones(15) * 0.1)
    tsInp = TSEvent([0.1], [0])

    ## -- FFIAFNest
    # - Layer generation
    fl0 = FFIAFNest(
        weights=weights_in, dt=0.001, bias=bias, tau_mem=tau_mem, refractory=0.001
    )
    fl1 = FFIAFNest(
        weights=weights_in1, dt=0.001, bias=bias, tau_mem=tau_mem, refractory=0.001
    )
    try:
        # - Set input weights to same as fl0
        fl1.weights = weights_in
        assert (fl1.weights_ == weights_in).all()

        # - Compare states after evolving both layers on the same input
        fl0.evolve(tsInpCont, duration=0.12)
        fl1.evolve(tsInpCont, duration=0.12)
        assert (fl0.state == fl1.state).all()
    finally:
        # - Always terminate, even when an assertion above fails
        fl0.terminate()
        fl1.terminate()

    ## -- RecIAFSpkInNest
    # - Layer generation
    fl0 = RecIAFSpkInNest(
        weights_in=weights_in,
        weights_rec=weights_rec,
        dt=0.001,
        bias=bias,
        tau_mem=tau_mem,
        tau_syn_exc=tau_syn_exc,
        tau_syn_inh=tau_syn_inh,
        refractory=0.001,
        record=True,
    )
    fl1 = RecIAFSpkInNest(
        weights_in=weights_in1,
        weights_rec=weights_rec,
        dt=0.001,
        bias=bias,
        tau_mem=tau_mem,
        tau_syn_exc=tau_syn_exc,
        tau_syn_inh=tau_syn_inh,
        refractory=0.001,
        record=True,
    )
    try:
        # - Set input weights to same as fl0
        fl1.weights_in = weights_in
        assert (fl1.weights_in_ == weights_in).all()

        # - Compare states after evolving both layers on the same input
        fl0.evolve(tsInp, duration=0.12)
        fl1.evolve(tsInp, duration=0.12)
        assert (fl0.state == fl1.state).all()
    finally:
        fl0.terminate()
        fl1.terminate()

    ## -- RecAEIFSpkInNest
    # - Layer generation
    fl0 = RecAEIFSpkInNest(
        weights_in=weights_in,
        weights_rec=weights_rec,
        dt=0.001,
        bias=bias,
        tau_mem=tau_mem,
        tau_syn_exc=tau_syn_exc,
        tau_syn_inh=tau_syn_inh,
        refractory=0.001,
        record=True,
    )
    fl1 = RecAEIFSpkInNest(
        weights_in=weights_in1,
        weights_rec=weights_rec,
        dt=0.001,
        bias=bias,
        tau_mem=tau_mem,
        tau_syn_exc=tau_syn_exc,
        tau_syn_inh=tau_syn_inh,
        refractory=0.001,
        record=True,
    )
    try:
        # - Set input weights to same as fl0 by writing the single differing
        #   element in place (exercises element-wise assignment)
        fl1.weights_in[0, 0] = weights_in[0, 0]
        assert (fl1.weights_in_ == weights_in).all()

        # - Compare states after evolving both layers on the same input
        fl0.evolve(tsInp, duration=0.12)
        fl1.evolve(tsInp, duration=0.12)
        assert (fl0.state == fl1.state).all()
    finally:
        fl0.terminate()
        fl1.terminate()
Exemplo n.º 20
0
def test_ff_rate_euler_train():
    """Test ridge regression for FFRateEuler.

    ``FFExpSyn`` trained directly on spikes must reach the same weights and
    biases as ``FFRateEuler`` and ``PassThrough`` trained on the same spikes
    after exponential filtering.
    """
    from rockpool import TSContinuous, TSEvent
    from rockpool.layers import FFRateEuler, FFExpSyn, PassThrough

    # - Layers
    size_in = 6
    size = 3
    dt = 0.001
    tau_syn = 0.15

    # - FFExpSyn layer to filter spike trains
    fl_exp_prepare = FFExpSyn(np.eye(size_in), dt=dt, tau_syn=tau_syn)

    # - Spiking input signal: nSpikes evenly spaced events cycling through
    #   the input channels
    tDur = 0.01
    nSpikes = 5

    # - Tile the channel indices ceil(nSpikes / size_in) times so that there
    #   is one channel index per spike time, whatever nSpikes is
    num_tiles = int(np.ceil(nSpikes / size_in))
    vnC = np.tile(np.arange(size_in), num_tiles)[:nSpikes]
    vtT = np.linspace(0, tDur, nSpikes, endpoint=False)
    tsIn = TSEvent(vtT, vnC, num_channels=size_in)

    # - Filter signal
    ts_filtered = fl_exp_prepare.evolve(tsIn)

    # - Another FFExpSyn layer to compare training
    fl_exp_train = FFExpSyn(np.zeros((size_in, size)), dt=dt, tau_syn=tau_syn)
    # - Rate layers to be trained
    fl_rate = FFRateEuler(np.zeros((size_in, size)), dt=dt, noise_std=0)
    fl_pt = PassThrough(np.zeros((size_in, size)), dt=dt, noise_std=0)

    # - Target: phase-shifted sines, one per output channel
    tgt_samples = np.array(
        [
            np.sin(np.linspace(0, 10 * tDur, int(tDur / dt)) + fPhase)
            for fPhase in np.linspace(0, 3, size)
        ]
    ).T
    ts_tgt = TSContinuous(np.arange(int(tDur / dt)) * dt, tgt_samples)

    # - Train all three layers in a single batch
    fl_exp_train.train_rr(ts_tgt, tsIn, regularize=0.1, is_first=True, is_last=True)
    fl_rate.train_rr(
        ts_tgt, ts_filtered, regularize=0.1, is_first=True, is_last=True
    )
    fl_pt.train_rr(ts_tgt, ts_filtered, regularize=0.1, is_first=True, is_last=True)

    assert (
        np.isclose(fl_exp_train.weights, fl_rate.weights, rtol=1e-4, atol=1e-2).all()
        and np.isclose(fl_exp_train.bias, fl_rate.bias, rtol=1e-4, atol=1e-2).all()
        and np.isclose(fl_exp_train.weights, fl_pt.weights, rtol=1e-4, atol=1e-2).all()
        and np.isclose(fl_exp_train.bias, fl_pt.bias, rtol=1e-4, atol=1e-2).all()
    ), "Training led to different results"
Exemplo n.º 21
0
def test_continuous_methods():
    """Smoke-test the public methods of ``TSContinuous``.

    Covers interpolation, delay, contains, resample, merge, append, isempty,
    clip and min/max.

    Negative checks use ``not`` rather than ``~``. If ``contains`` /
    ``isempty`` return a Python ``bool``, ``~`` yields ``-1`` or ``-2``, both
    truthy, so ``assert ~x`` could never fail.
    """
    from rockpool import TSContinuous

    ts1 = TSContinuous([0, 1, 2], [0, 1, 2])

    # - Interpolation
    assert ts1(0) == 0
    assert ts1(2) == 2
    assert ts1(1.5) == 1.5

    assert ts1._interpolate(0) == 0
    assert ts1._interpolate(2) == 2
    assert ts1._interpolate(1.5) == 1.5

    # - Delay (returns a new series; the original is unchanged)
    ts2 = ts1.delay(1)
    assert ts1.t_start == 0
    assert ts2.t_start == 1

    ts20 = ts1.start_at_zero()
    assert ts20.t_start == 0

    # - Contains
    assert ts1.contains(0)
    assert not ts1.contains(-1)
    assert ts1.contains([0, 1, 2])
    assert not ts1.contains([0, 1, 2, 3])

    # - Resample
    ts2 = ts1.resample([0.1, 1.1, 1.9])

    # - Merge
    ts1 = TSContinuous([0, 1, 2], [0, 1, 2])
    ts2 = TSContinuous([0, 1, 2], [1, 2, 3])
    ts3 = ts1.merge(ts2, remove_duplicates=True)
    assert np.size(ts3.samples) == 3
    assert np.size(ts1.samples) == 3
    assert np.size(ts2.samples) == 3

    ts3 = ts1.merge(ts2, remove_duplicates=False)
    assert np.size(ts3.samples) == 6

    # - Append
    ts1 = TSContinuous([0, 1, 2], [0, 1, 2])
    ts2 = TSContinuous([0, 1, 2], [1, 2, 3])
    ts3 = ts1.append_t(ts2)
    assert np.size(ts3.times) == 6

    ts3 = ts1.append_c(ts2)
    assert ts3.num_channels == 2

    # - isempty
    assert not ts1.isempty()
    assert TSContinuous().isempty()

    # - clip
    ts2 = ts1.clip(0.5, 1.5)

    # - Min / Max
    ts1 = TSContinuous([0, 1, 2], [0, 1, 2])
    assert ts1.min == 0
    assert ts1.max == 2
Exemplo n.º 22
0
def test_continuous_inplace_mutation():
    """
    Test the ``inplace=True`` variants of ``TSContinuous`` methods.

    Each call mutates ``ts1`` directly, so the assertions inspect ``ts1``
    itself rather than the methods' return values. Covers ``delay``,
    ``resample``, ``merge``, ``append_t``, ``append_c``, ``clip`` and
    ``start_at_zero``.
    """
    from rockpool import TSContinuous

    ts1 = TSContinuous([0, 1, 2], [0, 1, 2])

    # - Delay
    ts1.delay(1, inplace=True)
    assert ts1.t_start == 1

    # - Resample
    ts1.resample([0.125, 1.1, 1.9], inplace=True)
    assert ts1.t_start == 0.125

    # - Merge
    ts1 = TSContinuous([0, 1, 2], [0, 1, 2])
    ts2 = TSContinuous([0, 1, 2], [1, 2, 3])
    ts1.merge(ts2, remove_duplicates=True, inplace=True)
    assert np.size(ts1.samples) == 3

    # - Check the mutated object itself, not the return value,
    #   so that in-place behaviour is actually verified
    ts1.merge(ts2, remove_duplicates=False, inplace=True)
    assert np.size(ts1.samples) == 6

    # - Append
    ts1 = TSContinuous([0, 1, 2], [0, 1, 2])
    ts2 = TSContinuous([0, 1, 2], [1, 2, 3])
    ts1.append_t(ts2, inplace=True)
    assert np.size(ts1.times) == 6

    ts1 = TSContinuous([0, 1, 2], [0, 1, 2])
    ts2 = TSContinuous([0, 1, 2], [1, 2, 3])
    ts1.append_c(ts2, inplace=True)
    assert ts1.num_channels == 2

    # - clip
    ts1.clip(0.5, 1.5, inplace=True)
    assert ts1.t_start == 0.5

    # - Start at 0
    ts1.start_at_zero(inplace=True)
    assert ts1.t_start == 0
Exemplo n.º 23
0
def test_continuous_append_t():
    """
    Test ``TSContinuous.append_t`` and ``TSContinuous.concatenate_t``.

    Checks t_start, t_stop, time trace and samples when:
    - appending one series to another,
    - appending a series to an empty series and vice versa,
    - appending a list of series with ``offset=None``.
    Finally checks that ``concatenate_t`` on the same list agrees with
    ``append_t``.
    """
    from rockpool import TSContinuous

    # - Generate a few TSContinuous objects
    samples = np.random.randint(10, size=(2, 6))
    empty_series = TSContinuous(t_start=1)
    series_list = [
        TSContinuous([1, 2], samples[:2, :2], t_start=-1, t_stop=3),
        # - Single time point carrying two channels
        TSContinuous([1], samples[0, 2:4], t_start=0, t_stop=2),
        TSContinuous([1, 3], samples[:2, 4:], t_start=0, t_stop=3),
    ]

    # - Appending two series
    appended_fromtwo = series_list[0].append_t(series_list[1])
    assert appended_fromtwo.t_start == -1, "Wrong t_start for appended series."
    assert appended_fromtwo.t_stop == 6, "Wrong t_stop for appended series."
    assert (
        appended_fromtwo.times == np.array([1, 2, 5])
    ).all(), "Wrong time trace for appended series."
    assert (appended_fromtwo.samples[:2] == samples[:2, :2]).all() and (
        appended_fromtwo.samples[2:] == samples[[0], 2:4]
    ).all(), "Wrong samples for appended series."

    # - Appending with empty series (empty series first)
    appended_empty_first = empty_series.append_t(series_list[0])
    assert (
        appended_empty_first.t_start == empty_series.t_start
    ), "Wrong t_start when appending with empty."
    assert (
        appended_empty_first.t_stop
        == series_list[0].duration + empty_series.t_start
    ), "Wrong t_stop when appending with empty."
    assert (
        appended_empty_first.samples == series_list[0].samples
    ).all(), "Wrong samples when appending with empty"
    assert (
        appended_empty_first.times == series_list[0].times + 2
    ).all(), "Wrong time trace when appending with empty"

    # - Appending with empty series (empty series last)
    appended_empty_last = series_list[0].append_t(empty_series)
    assert (
        appended_empty_last.t_start == series_list[0].t_start
    ), "Wrong t_start when appending with empty."
    # - The appended empty series only adds its own duration to t_stop
    assert (
        appended_empty_last.t_stop
        == series_list[0].t_stop + empty_series.duration
    ), "Wrong t_stop when appending with empty."
    assert (
        appended_empty_last.samples == series_list[0].samples
    ).all(), "Wrong samples when appending with empty"
    assert (
        appended_empty_last.times == series_list[0].times
    ).all(), "Wrong time trace when appending with empty"

    # - Appending multiple time series
    appended_fromthree = series_list[0].append_t(series_list[1:], offset=None)
    exptd_offset = np.median(np.diff(series_list[0].times))
    exptd_ts2_delay = (
        exptd_offset + series_list[0].t_stop - series_list[1].t_start
    )
    # - No offset between ts2 and ts3 because ts2 has only one element
    exptd_ts3_delay = (
        exptd_ts2_delay + series_list[1].t_stop - series_list[2].t_start
    )
    exptd_ts2_times = series_list[1].times + exptd_ts2_delay
    exptd_ts3_times = series_list[2].times + exptd_ts3_delay

    assert (
        appended_fromthree.times
        == np.r_[series_list[0].times, exptd_ts2_times, exptd_ts3_times]
    ).all(), "Wrong time trace when appending from list"
    assert (
        appended_fromthree.samples
        == np.vstack([series.samples for series in series_list])
    ).all(), "Wrong samples when appending from list"

    # - Generating from list of TSContinuous must match append_t on the list
    appended_fromlist = TSContinuous.concatenate_t(series_list)

    assert (
        appended_fromthree.times == appended_fromlist.times
    ).all(), "Wrong time trace when concatenating from list"
    assert (
        appended_fromthree.samples == appended_fromlist.samples
    ).all(), "Wrong samples when concatenating from list"