Exemple #1
0
def trials_error_plot(dt, prestime, t, ystar, learners, n_test_pre, n_train):
    """Plot post-training test outputs and the normalized RMS error.

    The top subplot shows the target ``ystar`` and each learner's output
    over the 20 presentations that follow the pre-test and training phases
    (drawn by ``output_plot``). The bottom subplot shows each learner's
    normalized RMS error over the whole run (drawn by ``error_plot``).

    Parameters
    ----------
    dt : float
        Simulation time step, used as the default step of the filters.
    prestime : float
        Duration of one stimulus presentation.
    t : array
        Time points of the recorded signals.
    ystar : array
        Target output signal, indexed by time along axis 0.
    learners : list of dict
        Each learner provides ``'y'`` (output) and ``'e'`` (error) arrays.
    n_test_pre, n_train : int
        Number of presentations in the pre-learning test and the training
        phases.
    """
    # Output window: the 20 presentations right after pre-test + training.
    t1 = (n_test_pre + n_train) * prestime
    t2 = t1 + 20 * prestime

    plt.figure()

    ax = plt.subplot(211)
    output_plot(dt, t, ystar, learners, tmin=t1, tmax=t2, ax=ax)

    ax = plt.subplot(212)
    # The error traces used to be commented out, leaving this subplot empty
    # despite its labels. error_plot smooths each learner's error with an
    # Alpha(5 * prestime) synapse and normalizes it by the RMS of ystar.
    error_plot(dt, prestime, t, ystar, learners, ax=ax)

    plt.xlabel('training time [s]')
    plt.ylabel('normalized RMS training error')
    plt.tight_layout()
Exemple #2
0
def output_plot(dt, t, ystar, learners, tmin=None, tmax=None, ax=None):
    """Plot the first two dimensions of the target and each learner's output.

    The target ``ystar`` is drawn in black. Each learner's ``'y'`` is
    smoothed with a zero-phase ``Alpha(0.01)`` filter (``filtfilt``) and
    drawn using the axes' color cycle. Dimension 0 uses solid lines and
    dimension 1 dotted lines; the color cycle is restarted for each
    dimension so a learner keeps the same color in both.

    Parameters
    ----------
    dt : float
        Time step used as the filter's default step.
    t : array
        Time points; signals are indexed by time along axis 0.
    ystar : array
        Target signal with at least two columns.
    learners : list of dict
        Each learner provides ``'y'``, indexed like ``ystar``.
    tmin, tmax : float, optional
        If either is given, only samples with ``tmin <= t <= tmax`` are
        plotted; a missing bound defaults to the first/last time point.
    ax : matplotlib Axes, optional
        Axes to draw into; defaults to the current axes.
    """
    ax = plt.gca() if ax is None else ax
    ys = [learner['y'] for learner in learners]

    if tmin is not None or tmax is not None:
        tmin = t[0] if tmin is None else tmin
        tmax = t[-1] if tmax is None else tmax
        tmask = (t >= tmin) & (t <= tmax)
        t = t[tmask]
        ystar = ystar[tmask]
        ys = [y[tmask] for y in ys]

    dinds = [0, 1]
    dstyles = ['-', ':']  # one line style per plotted dimension

    vsynapse = Alpha(0.01, default_dt=dt)
    ystar = ystar[:, dinds]
    ys = [vsynapse.filtfilt(y[:, dinds]) for y in ys]
    for k, dstyle in enumerate(dstyles):
        # set_prop_cycle(None) resets the cycle to the rcParams default; it
        # replaces Axes.set_color_cycle, which modern matplotlib removed.
        ax.set_prop_cycle(None)
        ax.plot(t, ystar[:, k], 'k', linestyle=dstyle)
        for y in ys:
            ax.plot(t, y[:, k], linestyle=dstyle)

    # Apply the limits to the axes we drew on, not whatever axes is current.
    ax.set_ylim((-1.5, 1.5))
Exemple #3
0
def output1_plot(dt, t, ystar, learner, tmin=None, tmax=None, ax=None):
    """Plot the first two dimensions of one learner's output and the target.

    The learner's ``'y'`` is smoothed with a zero-phase ``Alpha(0.01)``
    filter and drawn with solid lines. The target ``ystar`` is drawn dotted
    after restarting the color cycle, so each dimension's output and target
    share a color. A legend labels the dimensions.

    Parameters
    ----------
    dt : float
        Time step used as the filter's default step.
    t : array
        Time points; signals are indexed by time along axis 0.
    ystar : array
        Target signal with at least two columns.
    learner : dict
        Provides ``'y'``, indexed like ``ystar``.
    tmin, tmax : float, optional
        If either is given, only samples with ``tmin <= t <= tmax`` are
        plotted and the x-limits are set to that window.
    ax : matplotlib Axes, optional
        Axes to draw into; defaults to the current axes.
    """
    ax = plt.gca() if ax is None else ax
    y = learner['y']

    if tmin is not None or tmax is not None:
        tmin = t[0] if tmin is None else tmin
        tmax = t[-1] if tmax is None else tmax
        tmask = (t >= tmin) & (t <= tmax)
        t = t[tmask]
        ystar = ystar[tmask]
        y = y[tmask]

    dinds = [0, 1]

    vsynapse = Alpha(0.01, default_dt=dt)
    # Both signals are restricted to ``dinds`` here; plot them whole below
    # rather than indexing by ``dinds`` a second time, which only worked
    # because dinds happened to be [0, 1].
    ystar = ystar[:, dinds]
    y = vsynapse.filtfilt(y[:, dinds])

    ax.plot(t, y)
    ax.set_prop_cycle(None)  # replacement for the removed set_color_cycle
    ax.plot(t, ystar, ':')

    # Use the given axes (not the current one) for legend and limits.
    ax.legend(['dim %d' % (i + 1) for i in range(len(dinds))], loc='best')
    ax.set_xlim((tmin, tmax))
    ax.set_ylim((-1.5, 1.5))

    # Show absolute, non-scientific times on the x-axis.
    formatter = ax.get_xaxis().get_major_formatter()
    formatter.set_useOffset(False)
    formatter.set_scientific(False)
Exemple #4
0
def test_linear_filter_gradient(Simulator):
    """Gradient check through Alpha synapses on a connection and a probe."""
    with nengo.Network() as network:
        source = nengo.Node([1])
        sink = nengo.Node(size_in=1)
        nengo.Connection(source, sink, synapse=Alpha(0.01))
        nengo.Probe(sink, synapse=Alpha(0.1))

    with Simulator(network) as simulator:
        simulator.check_gradients()
Exemple #5
0
def test_alpha_merged(Simulator, seed):
    """Two Alpha-filtered probes must match the reference nengo.Simulator."""
    synapses = (Alpha(0.03), Alpha(0.1))

    with nengo.Network() as net:
        signal = nengo.processes.WhiteSignal(1, high=10, seed=seed)
        u = nengo.Node(output=signal)
        probes = [nengo.Probe(u, synapse=syn) for syn in synapses]

    # reference results from the standard nengo simulator
    with nengo.Simulator(net) as ref_sim:
        ref_sim.run(1)
        expected = [ref_sim.data[probe] for probe in probes]

    with Simulator(net) as sim:
        sim.run(1)
        for probe, reference in zip(probes, expected):
            assert np.allclose(sim.data[probe], reference, atol=5e-5)
Exemple #6
0
def error_plot(dt, prestime, t, ystar, learners, tmin=None, tmax=None, ax=None):
    """Plot each learner's normalized RMS error over time.

    Each learner's error ``'e'`` is processed in four steps:

    1. Smooth it with a zero-phase ``Alpha(0.01)`` filter.
    2. Reduce it with ``rms(..., axis=1)``.
    3. Divide by the mean of ``rms(..., axis=1)`` of the equally smoothed
       target ``ystar``.
    4. Smooth again with a slower ``Alpha(5 * prestime)`` filter.

    Parameters
    ----------
    dt : float
        Time step used as the filters' default step.
    prestime : float
        Duration of one stimulus presentation; sets the error smoothing.
    t : array
        Time points; signals are indexed by time along axis 0.
    ystar : array
        Target signal. The normalization constant is always computed from
        the full ``ystar``, even when a time window is requested.
    learners : list of dict
        Each learner provides ``'e'``, indexed like ``t``.
    tmin, tmax : float, optional
        If either is given, only error samples with ``tmin <= t <= tmax``
        are plotted; a missing bound defaults to the first/last time point.
    ax : matplotlib Axes, optional
        Axes to draw into; defaults to the current axes.
    """
    ax = plt.gca() if ax is None else ax
    es = [learner['e'] for learner in learners]

    if tmin is not None or tmax is not None:
        tmin = t[0] if tmin is None else tmin
        tmax = t[-1] if tmax is None else tmax
        tmask = (t >= tmin) & (t <= tmax)
        t = t[tmask]
        es = [e[tmask] for e in es]

    vsynapse = Alpha(0.01, default_dt=dt)  # fast smoothing of raw signals
    esynapse = Alpha(5 * prestime, default_dt=dt)  # slow smoothing of errors

    yrms = rms(vsynapse.filtfilt(ystar), axis=1).mean()
    for e in es:
        erms = esynapse.filtfilt(rms(vsynapse.filtfilt(e), axis=1) / yrms)
        # Draw into the requested axes rather than the current one.
        ax.plot(t, erms)
Exemple #7
0
def test_alpha(Simulator, seed):
    """Alpha(tau) must match LinearFilter([1], (tau*s + 1)**2 coefficients)."""
    tau = 0.03
    dt = 1e-3
    # (tau*s + 1)**2 expands to tau**2 * s**2 + 2*tau*s + 1
    numerator = [1]
    denominator = [tau ** 2, 2 * tau, 1]

    t, x, yhat = run_synapse(Simulator, seed, Alpha(tau), dt=dt)
    reference = LinearFilter(numerator, denominator)
    y = reference.filt(x, dt=dt, y0=0)

    assert allclose(t, y, yhat, delay=dt, atol=5e-5)
Exemple #8
0
def error_layers_plots(dt, t, learners):
    """Plot the overall error against per-layer errors, one figure per learner.

    Only learners with an ``'els'`` entry are plotted. All signals are first
    smoothed with a zero-phase ``Alpha(0.01)`` filter.

    - Top subplot: dimension 0 of ``'e'`` and of each array in ``'els'``.
    - Bottom subplot: ``norm(..., axis=1)`` of each of those signals.

    ``plt.show()`` is called once after all figures are created.
    """
    vsynapse = Alpha(0.01, default_dt=dt)
    dind = 0  # dimension shown in the top subplot

    for learner in learners:
        if 'els' not in learner:
            continue

        plt.figure()
        plt.subplot(211)

        e = vsynapse.filtfilt(learner['e'])
        els = [vsynapse.filtfilt(el) for el in learner['els']]
        # Plain loops instead of list comprehensions used only for side
        # effects (which built throwaway lists of Line2D objects).
        plt.plot(t, e[:, dind])
        for el in els:
            plt.plot(t, el[:, dind])

        plt.subplot(212)
        plt.plot(t, norm(e, axis=1))
        for el in els:
            plt.plot(t, norm(el, axis=1))

    plt.show()
Exemple #9
0
def test_argreprs():
    """Check synapse ``__init__`` argument names and ``repr`` round-trips."""
    def check_init_args(cls, args):
        # drop ``self`` before comparing positional argument names
        assert getfullargspec(cls.__init__).args[1:] == args

    def check_repr(obj):
        # the repr must evaluate back to an equal object
        assert eval(repr(obj)) == obj

    expected_args = [
        (LinearFilter, ['num', 'den', 'analog']),
        (Lowpass, ['tau']),
        (Alpha, ['tau']),
        (Triangle, ['t']),
    ]
    for cls, args in expected_args:
        check_init_args(cls, args)

    for obj in (LinearFilter([1, 2], [3, 4]),
                LinearFilter([1, 2], [3, 4], analog=False),
                Lowpass(0.3),
                Alpha(0.3),
                Triangle(0.3)):
        check_repr(obj)
Exemple #10
0
def test_argreprs():
    """Check synapse ``__init__`` argument names and ``repr`` round-trips."""
    init_args = {
        LinearFilter: ["num", "den", "analog", "method"],
        Lowpass: ["tau"],
        Alpha: ["tau"],
        Triangle: ["t"],
    }
    for cls, expected in init_args.items():
        # skip ``self`` when comparing the constructor's positional args
        assert getfullargspec(cls.__init__).args[1:] == expected

    examples = [
        LinearFilter([1, 2], [3, 4]),
        LinearFilter([1, 2], [3, 4], analog=False),
        Lowpass(0.3),
        Alpha(0.3),
        Triangle(0.3),
    ]
    for obj in examples:
        # the repr must evaluate back to an equal object
        assert eval(repr(obj)) == obj
Exemple #11
0
def trials_error_plot(prestime, t, ystar, learners):
    """Plot outputs (top) and smoothed error norms (bottom) for each learner.

    Top subplot: the first two dimensions of ``ystar`` followed by each
    learner's ``'y'`` smoothed with a zero-phase ``Alpha(0.02)`` filter.

    Bottom subplot: ``norm(..., axis=1)`` of each learner's ``'e'`` after
    smoothing with ``Alpha(5 * prestime)``.
    """
    # NOTE(review): both filters assume samples are 0.01 s apart,
    # independent of ``t`` -- confirm this matches how ``t`` was recorded.
    pdt = 0.01
    dims = slice(0, 2)
    output_filter = Alpha(0.02, default_dt=pdt)
    error_filter = Alpha(5 * prestime, default_dt=pdt)

    plt.figure()

    plt.subplot(211)
    plt.plot(t, ystar[:, dims])
    for learner in learners:
        plt.plot(t, output_filter.filtfilt(learner['y'][:, dims]))
    plt.ylabel('outputs')

    plt.subplot(212)
    for learner in learners:
        plt.plot(t, norm(error_filter.filtfilt(learner['e']), axis=1))
    plt.ylabel('errors')
Exemple #12
0
        # Case 3: neurons -> ens
        conn = nengo.Connection(
            ens1.neurons,
            ens2,
            transform=np.ones((1, ens1.n_neurons)),
            learning_rule_type={"pes": nengo.PES()},
        )
        nengo.Connection(err, conn.learning_rule["pes"])

    with Simulator(net) as sim:
        sim.run(0.01)


@pytest.mark.parametrize(
    "pre_synapse",
    [0, Lowpass(tau=0.05), Alpha(tau=0.005)])
def test_pes_synapse(Simulator, seed, pre_synapse, allclose):
    """Run a recurrent PES-learning connection with several ``pre_synapse``s.

    Two probes are recorded: the ensemble's neurons filtered by
    ``pre_synapse``, and the learning rule's ``"activities"``.

    NOTE(review): no assertion is visible here, and the ``allclose`` fixture,
    ``p_neurons`` and ``p_pes`` are never used. The test looks truncated; as
    written it only checks that the simulation runs without error.
    """
    rule = PES(pre_synapse=pre_synapse)

    with nengo.Network(seed=seed) as model:
        stim = nengo.Node(output=WhiteSignal(0.5, high=10))
        x = nengo.Ensemble(100, 1)

        nengo.Connection(stim, x, synapse=None)
        # recurrent connection whose decoders are trained by the PES rule
        conn = nengo.Connection(x, x, learning_rule_type=rule)

        # presumably meant to be compared with the rule's filtered
        # presynaptic activities -- TODO confirm against the full test
        p_neurons = nengo.Probe(x.neurons, synapse=pre_synapse)
        p_pes = nengo.Probe(conn.learning_rule, "activities")

    with Simulator(model) as sim:
        sim.run(0.5)
Exemple #13
0
def test_mergeable():
    """Check which operator combinations ``mergeable`` accepts or rejects."""
    Sig = DummySignal
    Op = DummyOp

    # anything is mergeable with an empty list
    assert mergeable(None, [])

    # ops with different numbers of sets/incs/reads/updates are not mergeable
    for kind in ("sets", "incs", "reads", "updates"):
        assert not mergeable(Op(**{kind: [Sig()]}), [Op()])
    assert mergeable(Op(sets=[Sig()]), [Op(sets=[Sig()])])

    # dtypes must match
    assert not mergeable(Op(sets=[Sig(dtype=np.float32)]),
                         [Op(sets=[Sig(dtype=np.float64)])])

    # shape mismatch
    assert not mergeable(Op(sets=[Sig(shape=(1, 2))]),
                         [Op(sets=[Sig(shape=(1, 3))])])

    # display shape mismatch (same base shape)
    assert not mergeable(
        Op(sets=[Sig(base_shape=(2, 2), shape=(4, 1))]),
        [Op(sets=[Sig(base_shape=(2, 2), shape=(1, 4))])])

    # a mismatch in the first dimension alone is allowed
    assert mergeable(Op(sets=[Sig(shape=(3, 2))]),
                     [Op(sets=[Sig(shape=(4, 2))])])

    # Copy (inc must match)
    assert mergeable(Copy(Sig(), Sig(), inc=True),
                     [Copy(Sig(), Sig(), inc=True)])
    assert not mergeable(Copy(Sig(), Sig(), inc=True),
                         [Copy(Sig(), Sig(), inc=False)])

    # elementwise (first dimension must match)
    assert mergeable(ElementwiseInc(Sig(), Sig(), Sig()),
                     [ElementwiseInc(Sig(), Sig(), Sig())])
    assert mergeable(ElementwiseInc(Sig(shape=(1,)), Sig(), Sig()),
                     [ElementwiseInc(Sig(shape=()), Sig(), Sig())])
    assert not mergeable(ElementwiseInc(Sig(shape=(3,)), Sig(), Sig()),
                         [ElementwiseInc(Sig(shape=(2,)), Sig(), Sig())])

    # simpyfunc (t input must match)
    shared_time = Sig()
    assert mergeable(SimPyFunc(None, None, shared_time, None),
                     [SimPyFunc(None, None, shared_time, None)])
    assert mergeable(SimPyFunc(None, None, None, Sig()),
                     [SimPyFunc(None, None, None, Sig())])
    assert not mergeable(SimPyFunc(None, None, Sig(), None),
                         [SimPyFunc(None, None, None, Sig())])

    # simneurons: check matching TF_NEURON_IMPL
    assert mergeable(SimNeurons(LIF(), Sig(), Sig()),
                     [SimNeurons(LIF(), Sig(), Sig())])
    assert not mergeable(SimNeurons(LIF(), Sig(), Sig()),
                         [SimNeurons(LIFRate(), Sig(), Sig())])

    # custom vs non-custom neuron implementation
    assert not mergeable(SimNeurons(LIF(), Sig(), Sig()),
                         [SimNeurons(Izhikevich(), Sig(), Sig())])

    # non-custom neuron implementations
    assert not mergeable(SimNeurons(Izhikevich(), Sig(), Sig()),
                         [SimNeurons(AdaptiveLIF(), Sig(), Sig())])
    assert not mergeable(
        SimNeurons(Izhikevich(), Sig(), Sig(),
                   states=[Sig(dtype=np.float32)]),
        [SimNeurons(Izhikevich(), Sig(), Sig(),
                    states=[Sig(dtype=np.int32)])])
    assert mergeable(
        SimNeurons(Izhikevich(), Sig(), Sig(), states=[Sig(shape=(3,))]),
        [SimNeurons(Izhikevich(), Sig(), Sig(), states=[Sig(shape=(2,))])])
    assert not mergeable(
        SimNeurons(Izhikevich(), Sig(), Sig(), states=[Sig(shape=(2, 1))]),
        [SimNeurons(Izhikevich(), Sig(), Sig(), states=[Sig(shape=(2, 2))])])

    # simprocess: mode must match
    assert not mergeable(
        SimProcess(Lowpass(0), None, None, Sig(), mode="inc"),
        [SimProcess(Lowpass(0), None, None, Sig(), mode="set")])

    # matching TF_PROCESS_IMPL
    # note: TF_PROCESS_IMPL currently has one item, so there is no
    # mismatch case to test
    assert mergeable(SimProcess(Lowpass(0), None, None, Sig()),
                     [SimProcess(Lowpass(0), None, None, Sig())])

    # custom vs non-custom process
    assert not mergeable(SimProcess(Lowpass(0), None, None, Sig()),
                         [SimProcess(Alpha(0), None, None, Sig())])

    # non-custom processes
    assert mergeable(SimProcess(Triangle(0), None, None, Sig()),
                     [SimProcess(Alpha(0), None, None, Sig())])

    # simtensornode: not mergeable, not even with itself
    node_op = SimTensorNode(None, Sig(), None, Sig())
    assert not mergeable(node_op, [node_op])

    # learning rules whose first signals differ in shape
    bcm_small = SimBCM(Sig((4,)), Sig(), Sig(), Sig(), Sig())
    bcm_large = SimBCM(Sig((5,)), Sig(), Sig(), Sig(), Sig())
    assert not mergeable(bcm_small, [bcm_large])
Exemple #14
0
# --- plots
n_batches = len(learners[0].batch_errors)
n_hids = len(dhids)
batch_inds = (n_per_batch / 1000.) * np.arange(n_batches)
# epoch_inds = (trainX.shape[0] / 1000.) * np.arange(1, epochs+1)

fig = plt.figure()
rows, cols = 5, 1
# layer_styles = ('-', '-.', ':', '--')
layer_styles = ('-', '--', ':')

# filt = Alpha(10, default_dt=n_per_batch)
# filt = Alpha(20, default_dt=n_per_batch)
# filt = Alpha(30, default_dt=n_per_batch)
filt = Alpha(100, default_dt=n_per_batch)

ax = fig.add_subplot(rows, cols, 1)
ax.set_yscale('log')
batch_errors = np.array([learner.batch_errors for learner in learners])
ax.plot(batch_inds, filt.filtfilt(batch_errors, axis=1).T)

ax = fig.add_subplot(rows, cols, 2)
ax.set_yscale('log')
delta_norms = np.array([
    x.delta_norms if x.delta_norms else np.nan * np.ones((n_batches, n_hids))
    for x in learners
])
delta_norms = filt.filtfilt(delta_norms, axis=1)
for i in range(n_hids):
    ax.set_color_cycle(None)
Exemple #15
0
    if len(learner.batch_errors) == 0:
        learner.train(epochs, batch_fn, test_set=test_set)

for learner in learners_n:
    print(", ".join("||W%d|| = %0.3f" % (i, norm(w))
                    for i, w in enumerate(learner.network.weights)))

# --- plot results (cols=[unnormalized, normalized], traces=learners)
learner_groups = (learners_u, learners_n)
learner_group_names = ('Unnormalized B', 'Normalized B')
cols = len(learner_groups)
rows = 1

plt.figure(figsize=(6.4, 4))

filt = Alpha(30, default_dt=n_per_batch)

for i, (learners, name) in enumerate(zip(learner_groups, learner_group_names)):
    ax = plt.subplot(rows, cols, i + 1)

    for learner in learners:
        x = learner.batch_errors / Yrms
        y = filt.filtfilt(x) if len(x) > 0 else []
        batch_inds = n_per_batch * np.arange(len(x))
        ax.semilogy(batch_inds, y, label=learner.name)

    plt.ylim([1e-1, 1.2e0])
    plt.xlabel('# of examples')
    plt.ylabel('normalized RMS error')
    plt.legend(loc=1)
    plt.title(name)
Exemple #16
0
def cosyne_plot(dt, prestime, t, ystar, learners, n_test_pre, n_train,
                offline_data=None):
    vsynapse = Alpha(0.01, default_dt=dt)

    learner = learners[-1]

    n_show_pre = 10
    n_show_post = 10
    assert n_show_pre <= n_test_pre

    t0 = n_test_pre*prestime
    t1 = t0 + n_train*prestime
    # t2 = t1 + n_test_post*prestime

    tpre0 = 0
    tpre1 = tpre0 + n_show_pre*prestime
    tpost0 = t1
    # tpost0 = t1 + 10*prestime
    tpost1 = tpost0 + n_show_post*prestime

    plt.figure(figsize=(6.4, 7))

    # subplot_shape = (2, 3)
    subplot_shape = (3, 3)

    ax = plt.subplot2grid(subplot_shape, (0, 2))
    # output_plot(dt, t, ystar, learners, tmin=tpre0, tmax=tpre1, ax=ax)
    output1_plot(dt, t, ystar, learner, tmin=tpre0, tmax=tpre1, ax=ax)
    plt.title('Pre-learning output')

    ax = plt.subplot2grid(subplot_shape, (1, 2))
    # output_plot(dt, t, ystar, learners, tmin=tpost0, tmax=tpost1, ax=ax)
    output1_plot(dt, t, ystar, learner, tmin=tpost0, tmax=tpost1, ax=ax)
    plt.xlabel('simulation time [s]')
    plt.title('Post-learning output')

    hid_spikes0 = learner['hs'][0][:, :40]
    hid_spikes1 = learner['hs'][-1][:, :40]

    ax = plt.subplot2grid(subplot_shape, (0, 0))
    spike_plot(dt, t, hid_spikes0, tmin=tpre0, tmax=tpre1, ax=ax)
    plt.title('Pre-learning spikes 1')

    ax = plt.subplot2grid(subplot_shape, (1, 0))
    spike_plot(dt, t, hid_spikes0, tmin=tpost0, tmax=tpost1, ax=ax)
    plt.xlabel('simulation time [s]')
    plt.title('Post-learning spikes 1')

    ax = plt.subplot2grid(subplot_shape, (0, 1))
    spike_plot(dt, t, hid_spikes1, tmin=tpre0, tmax=tpre1, ax=ax)
    plt.title('Pre-learning spikes 2')

    ax = plt.subplot2grid(subplot_shape, (1, 1))
    spike_plot(dt, t, hid_spikes1, tmin=tpost0, tmax=tpost1, ax=ax)
    plt.xlabel('simulation time [s]')
    plt.title('Post-learning spikes 2')

    ax = plt.subplot2grid(subplot_shape, (2, 0), colspan=3)
    error_plot(dt, prestime, t, ystar, [learner], tmin=t0, tmax=t1, ax=ax)
    if offline_data:
        Yrms = rms(offline_data['Y'], axis=1).mean()
        eo = offline_data['learners'][-1]['batch_errors'] / Yrms
        dto = prestime*offline_data['n_per_batch']
        to = t0 + dto*np.arange(len(eo))

        tmask = (to >= t0) & (to <= t1)
        to = to[tmask]
        eo = eo[tmask]

        esynapse = Alpha(5*prestime, default_dt=dto)
        eo = esynapse.filtfilt(eo)

        plt.plot(to, eo, 'k:')
        plt.xlim((t0, t1))
        plt.legend(['spiking', 'non-spiking'], loc='best')

    plt.xlabel('simulation time [s]')
    plt.ylabel('normalized RMS error')
    plt.title('Error')

    plt.tight_layout()
Exemple #17
0
def plot_batches(x, label=None):
    """Plot per-batch values ``x`` on a log y-axis after Alpha(200) smoothing.

    The x-axis counts examples, ``n_per_batch`` per batch. An empty ``x`` is
    plotted as an empty line instead of being filtered.
    """
    smoother = Alpha(200, default_dt=n_per_batch)
    n_points = len(x)
    smoothed = smoother.filtfilt(x) if n_points else []
    plt.semilogy(n_per_batch * np.arange(n_points), smoothed, label=label)
# x = np.linspace(-1, 1, 10001)
# for i, (f, df) in enumerate(f_dfs):
#     f_df_labels[i] += ' (max %0.1f)' % (df(x).max() / amp)

# --- plot results (rows=[train, test], cols=learners)
n_trials = batch_errors.shape[0]
n_fdfs = batch_errors.shape[1]
n_learners = batch_errors.shape[2]
assert n_learners == len(learner_names)

rows = 2
cols = n_learners

plt.figure(figsize=(7, 6))

filt = Alpha(3000, default_dt=n_per_batch)
# filt = Alpha(10000, default_dt=n_per_batch)
for col in range(cols):
    ax = plt.subplot(rows, cols, col + 1)

    # error = filt.filt(batch_errors[:, :, col, :], axis=-1)
    error = filt.filtfilt(batch_errors[:, :, col, :], axis=-1)
    # batch_inds = n_per_batch * np.arange(error.shape[-1])
    batch_inds = (n_per_batch / 1000.) * np.arange(error.shape[-1])
    error, batch_inds = error[..., ::10], batch_inds[::10]

    sns.tsplot(
        data=np.transpose(error, (0, 2, 1)),
        time=batch_inds,
        condition=f_df_labels,
        # err_style='unit_traces',
Exemple #19
0
# --- plot results (cols=[train, test], traces=learners)
fig = plt.figure(figsize=(6.4, 5))
rows = 2
cols = 2

n_batches = len(learners[0].batch_errors)
batch_inds = (n_per_batch / 1000.) * np.arange(n_batches)
epoch_inds = (trainX.shape[0] / 1000.) * np.arange(1, epochs+1)

# - train subplot
ax = fig.add_subplot(rows, cols, 1)

# filt = Alpha(1000, default_dt=n_per_batch)
# filt = Alpha(3000, default_dt=n_per_batch)
filt = Alpha(10000, default_dt=n_per_batch)

for learner in learners:
    y = filt.filtfilt(learner.batch_errors)
    ax.semilogy(batch_inds, y, label=learner.name)

# plt.ylim([1e-4, 5e-1])
plt.ylim([None, 2e-1])
plt.xlabel('thousands of examples')
plt.ylabel('train error')
plt.legend(loc='best')
# plt.title("Train error")

# - test subplot
ax = fig.add_subplot(rows, cols, 2)
Exemple #20
0
#     # plt.plot(Yshow[:, 0], Yshow[:, 1], 'k.')

#     for i, (label, color) in enumerate(zip(f_df_labels, f_df_colors)):
#         learner = results[i][col]
#         Zshow = learner.network.predict(Xshow)
#         plot_error_line(Yshow, Zshow, c=color)
#         # plt.plot(Zshow[:, 0], Zshow[:, 1], '.')

# --- plot results (rows=[train, test], cols=learners)
rows = 2
cols = len(results[0])

plt.figure(figsize=(7, 6))

# filt = Alpha(3000, default_dt=n_per_batch)
filt = Alpha(10000, default_dt=n_per_batch)
for col in range(cols):
    ax = plt.subplot(rows, cols, col + 1)

    for i, label in enumerate(f_df_labels):
        learner = results[i][col]
        x = learner.batch_errors
        # y = x
        y = filt.filtfilt(x) if len(x) > 0 else []
        batch_inds = n_per_batch * np.arange(len(x))
        ax.semilogy(batch_inds, y, label=label)

    plt.ylim([5e-3, 2e-1])
    ax.set_xticklabels([])
    if col == 0:
        plt.ylabel('train error')
Exemple #21
0
    # learners = [bp_learner, fas_learner]
    # learners = [bp_learner, fa_learner]
    # learners = [bp_learner, fa_learner, fal_learner, fas_learner]

    for learner in learners:
        learner.train(epochs, batch_fn, test_set=test_set)

    for learner in learners:
        print(", ".join("||W%d|| = %0.3f" % (i, norm(w))
                        for i, w in enumerate(learner.network.weights)))

# --- plots
fig = plt.figure()
rows, cols = 4, 1

filt = Alpha(100, default_dt=n_per_batch)

ax = fig.add_subplot(rows, cols, 1)
ax.set_yscale('log')
for learner in learners:
    ax.plot(filt.filtfilt(learner.batch_errors), label=learner.name)

ax = fig.add_subplot(rows, cols, 2)
ax.set_yscale('log')
for learner in learners:
    if learner.delta_norms is not None:
        ax.plot(filt.filtfilt(learner.delta_norms), label=learner.name)

ax = fig.add_subplot(rows, cols, 3)
for learner in learners:
    if getattr(learner, 'bp_angles', None) is not None:
Exemple #22
0
learner_names = [learner.name for learner in learners]
# trial_errors = np.array(trial_errors)
eta_errors = np.array(eta_errors)
batch_inds = n_per_batch * np.arange(eta_errors.shape[-1])

# --- plot results (traces=learners)
plt.figure(figsize=(6.35, 6.0))

rows, cols = 2, 2
assert len(etas) <= rows * cols

for i, eta in enumerate(etas):
    trial_errors = eta_errors[i]

    filt = Alpha(30, default_dt=n_per_batch)
    trial_errors[~np.isfinite(trial_errors)] = 1e6
    trial_errors = trial_errors.clip(None, 1e6)
    trial_errors = filt.filt(trial_errors / Yrms, axis=-1)
    # trial_errors = filt.filtfilt(trial_errors / Yrms, axis=-1)

    ax = plt.subplot(rows, cols, i + 1)
    # ax = sns.tsplot(data=np.transpose(trial_errors, (0, 2, 1)),
    #                 time=batch_inds, condition=learner_names)
    sns.tsplot(data=np.transpose(trial_errors, (0, 2, 1)),
               time=batch_inds,
               condition=learner_names,
               err_style='unit_traces',
               legend=(i == 0))

    # ax.set(yscale='log')
x = np.linspace(-1, 1, 10001)
for i, (f, df) in enumerate(f_dfs):
    f_df_labels[i] += ' (max %0.1f)' % (df(x).max() / amp)

# --- plot results
errors0 = list(errors.values())[0]
n_trials = len(errors0)
n_fdfs = errors0[0].shape[0]
assert len(f_dfs) == len(f_df_labels) == n_fdfs
assert len(learner_names) == errors0[0].shape[1]

rows = len(etas)
cols = len(learner_names)

# filt = Alpha(30, default_dt=n_per_batch)
filt = Alpha(80, default_dt=n_per_show)
# filt = Alpha(100, default_dt=n_per_show)

plt.figure(figsize=(6, 8))
for row in range(rows):  # for each eta
    for col in range(cols):  # for each learner
        print("Plotting (%d, %d)" % (row, col))
        ax = plt.subplot(rows, cols, row * cols + col + 1)
        eta = etas[row]

        # (n_trials x n_fdfs x n) matrix of errors
        # print(type(errors[eta][0]))
        # print(errors[eta][0].shape)
        error = np.array(
            [[errors[eta][itrial][ifdf, col] for ifdf in range(n_fdfs)]
             for itrial in range(n_trials)])
Exemple #24
0
 def plot_batches(x, label=None, color=None):
     """Plot Alpha(10)-smoothed per-batch values ``x`` divided by ``Yrms``.

     The x-axis counts examples, ``n_per_batch`` per batch.
     """
     smoother = Alpha(10, default_dt=n_per_batch)
     normalized = smoother.filtfilt(x) / Yrms
     examples_seen = n_per_batch * np.arange(len(x))
     plt.plot(examples_seen, normalized, label=label, color=color)
def test_mergeable():
    """Check which operator combinations ``mergeable`` accepts or rejects."""
    Sig = dummies.Signal
    Op = dummies.Op

    # anything is mergeable with an empty list
    assert mergeable(None, [])

    # ops with different numbers of sets/incs/reads/updates are not mergeable
    for kind in ("sets", "incs", "reads", "updates"):
        assert not mergeable(Op(**{kind: [Sig()]}), [Op()])
    assert mergeable(Op(sets=[Sig()]), [Op(sets=[Sig()])])

    # dtypes must match
    assert not mergeable(Op(sets=[Sig(dtype=np.float32)]),
                         [Op(sets=[Sig(dtype=np.float64)])])

    # shape mismatch
    assert not mergeable(Op(sets=[Sig(shape=(1, 2))]),
                         [Op(sets=[Sig(shape=(1, 3))])])

    # display shape mismatch (same base shape)
    assert not mergeable(
        Op(sets=[Sig(base_shape=(2, 2), shape=(4, 1))]),
        [Op(sets=[Sig(base_shape=(2, 2), shape=(1, 4))])])

    # a mismatch in the first dimension alone is allowed
    assert mergeable(Op(sets=[Sig(shape=(3, 2))]),
                     [Op(sets=[Sig(shape=(4, 2))])])

    # Copy (inc must match)
    assert mergeable(Copy(Sig(), Sig(), inc=True),
                     [Copy(Sig(), Sig(), inc=True)])
    assert not mergeable(Copy(Sig(), Sig(), inc=True),
                         [Copy(Sig(), Sig(), inc=False)])

    # elementwise (first dimension must match)
    assert mergeable(ElementwiseInc(Sig(), Sig(), Sig()),
                     [ElementwiseInc(Sig(), Sig(), Sig())])
    assert mergeable(ElementwiseInc(Sig(shape=(1,)), Sig(), Sig()),
                     [ElementwiseInc(Sig(shape=()), Sig(), Sig())])
    assert not mergeable(ElementwiseInc(Sig(shape=(3,)), Sig(), Sig()),
                         [ElementwiseInc(Sig(shape=(2,)), Sig(), Sig())])

    # simpyfunc (t input must match)
    shared_time = Sig()
    assert mergeable(SimPyFunc(None, None, shared_time, None),
                     [SimPyFunc(None, None, shared_time, None)])
    assert mergeable(SimPyFunc(None, None, None, Sig()),
                     [SimPyFunc(None, None, None, Sig())])
    assert not mergeable(SimPyFunc(None, None, Sig(), None),
                         [SimPyFunc(None, None, None, Sig())])

    # simneurons: check matching TF_NEURON_IMPL
    assert mergeable(SimNeurons(LIF(), Sig(), Sig()),
                     [SimNeurons(LIF(), Sig(), Sig())])
    assert not mergeable(SimNeurons(LIF(), Sig(), Sig()),
                         [SimNeurons(LIFRate(), Sig(), Sig())])

    # custom vs non-custom neuron implementation
    assert not mergeable(SimNeurons(LIF(), Sig(), Sig()),
                         [SimNeurons(Izhikevich(), Sig(), Sig())])

    # non-custom neuron implementations
    assert not mergeable(SimNeurons(Izhikevich(), Sig(), Sig()),
                         [SimNeurons(AdaptiveLIF(), Sig(), Sig())])
    assert not mergeable(
        SimNeurons(Izhikevich(), Sig(), Sig(),
                   states=[Sig(dtype=np.float32)]),
        [SimNeurons(Izhikevich(), Sig(), Sig(),
                    states=[Sig(dtype=np.int32)])])
    assert mergeable(
        SimNeurons(Izhikevich(), Sig(), Sig(), states=[Sig(shape=(3,))]),
        [SimNeurons(Izhikevich(), Sig(), Sig(), states=[Sig(shape=(2,))])])
    assert not mergeable(
        SimNeurons(Izhikevich(), Sig(), Sig(), states=[Sig(shape=(2, 1))]),
        [SimNeurons(Izhikevich(), Sig(), Sig(), states=[Sig(shape=(2, 2))])])

    # simprocess: mode must match
    assert not mergeable(
        SimProcess(Lowpass(0), None, Sig(), Sig(), mode="inc"),
        [SimProcess(Lowpass(0), None, Sig(), Sig(), mode="set")])

    # two lowpass processes match
    assert mergeable(SimProcess(Lowpass(0), None, None, Sig()),
                     [SimProcess(Lowpass(0), None, None, Sig())])

    # lowpass and linear processes don't match
    assert not mergeable(SimProcess(Lowpass(0), None, None, Sig()),
                         [SimProcess(Alpha(0), None, None, Sig())])

    # two linear processes do match
    assert mergeable(
        SimProcess(Alpha(0.1), Sig(), None, Sig()),
        [SimProcess(LinearFilter([1], [1, 1, 1]), Sig(), None, Sig())])

    # custom and non-custom processes don't match
    assert not mergeable(SimProcess(Triangle(0), None, None, Sig()),
                         [SimProcess(Alpha(0), None, None, Sig())])

    # two non-custom processes of the same type match
    assert mergeable(SimProcess(Triangle(0), None, None, Sig()),
                     [SimProcess(Triangle(0), None, None, Sig())])

    # simtensornode: not mergeable, not even with itself
    node_op = SimTensorNode(None, Sig(), None, Sig())
    assert not mergeable(node_op, [node_op])

    # learning rules whose first signals differ in shape
    bcm_small = SimBCM(Sig((4,)), Sig(), Sig(), Sig(), Sig())
    bcm_large = SimBCM(Sig((5,)), Sig(), Sig(), Sig(), Sig())
    assert not mergeable(bcm_small, [bcm_large])
        )
        nengo.Connection(err, conn.learning_rule["pes"])
        # Case 3: neurons -> ens
        conn = nengo.Connection(
            ens1.neurons,
            ens2,
            transform=np.ones((1, ens1.n_neurons)),
            learning_rule_type={"pes": nengo.PES()},
        )
        nengo.Connection(err, conn.learning_rule["pes"])

    with Simulator(net) as sim:
        sim.run(0.01)


@pytest.mark.parametrize("pre_synapse", [0, Lowpass(tau=0.05), Alpha(tau=0.005)])
def test_pes_synapse(Simulator, seed, pre_synapse, allclose):
    rule = PES(pre_synapse=pre_synapse)

    with nengo.Network(seed=seed) as model:
        stim = nengo.Node(output=WhiteSignal(0.5, high=10))
        x = nengo.Ensemble(100, 1)

        nengo.Connection(stim, x, synapse=None)
        conn = nengo.Connection(x, x, learning_rule_type=rule)

        p_neurons = nengo.Probe(x.neurons, synapse=pre_synapse)
        p_pes = nengo.Probe(conn.learning_rule, "activities")

    with Simulator(model) as sim:
        sim.run(0.5)
        p_linear0 = nengo.Probe(inp, synapse=LinearFilter([1], [tau, 1]))
        p_linear1 = nengo.Probe(inp, synapse=LinearFilter([1], [tau * 2, 1]))

    with Simulator(net) as sim:
        sim.run(0.1)

        assert np.allclose(sim.data[p_lowpass0], sim.data[p_linear0])
        assert np.allclose(sim.data[p_lowpass1], sim.data[p_linear1])


@pytest.mark.parametrize(
    "synapse",
    (
        LinearFilter([0.1], [1], analog=False),  # NoX
        LinearFilter([1], [0.1, 1]),  # OneX
        Alpha(0.1),  # NoD
        LinearFilter(
            [0.0004166, 0.0016664, 0.0024996, 0.0016664, 0.0004166],
            [1.0, -3.18063855, 3.86119435, -2.11215536, 0.43826514],
        ),  # General
    ),
)
def test_linearfilter_minibatched(Simulator, synapse):
    n_steps = 10
    mini_size = 4
    signal_d = 3

    with nengo.Network() as net:
        inp = nengo.Node(np.zeros(signal_d))

        p0 = nengo.Probe(inp, synapse=synapse)