Esempio n. 1
0
def test_solvers():
    """Check constructor arguments and ``repr`` output of the solver classes.

    For each least-squares solver the expected init-arg names are verified,
    ``check_repr`` is run on an instance, and the exact ``repr`` string is
    compared. ``NoSolver`` is handled on its own because its ``repr`` check
    uses a different constructor call (positional list) than ``check_repr``
    (keyword ``np.array``).
    """
    repr_cases = [
        (
            Lstsq,
            ["weights", "rcond"],
            dict(weights=True, rcond=0.1),
            "Lstsq(weights=True, rcond=0.1)",
        ),
        (
            LstsqNoise,
            ["weights", "noise", "solver"],
            dict(weights=True, noise=0.2),
            "LstsqNoise(weights=True, noise=0.2)",
        ),
        (
            LstsqL2,
            ["weights", "reg", "solver"],
            dict(weights=True, reg=0.2),
            "LstsqL2(weights=True, reg=0.2)",
        ),
        (
            LstsqL2nz,
            ["weights", "reg", "solver"],
            dict(weights=True, reg=0.2),
            "LstsqL2nz(weights=True, reg=0.2)",
        ),
    ]
    for solver_cls, arg_names, kwargs, expected_repr in repr_cases:
        check_init_args(solver_cls, arg_names)
        check_repr(solver_cls(**kwargs))
        assert repr(solver_cls(**kwargs)) == expected_repr

    check_init_args(NoSolver, ["values", "weights"])
    check_repr(NoSolver(values=np.array([[1.2, 3.4, 5.6, 7.8]]), weights=True))
    no_solver_repr = repr(NoSolver([[1.2, 3.4, 5.6, 7.8]], weights=True))
    assert no_solver_repr == (
        "NoSolver(values=array([[1.2, 3.4, 5.6, 7.8]]), weights=True)"
    )
Esempio n. 2
0
def test_compare_solvers(Simulator, plt, seed, allclose):
    """Compare decoder and weight solvers on a ramp-tracking network.

    A 1-D ensemble ``a`` driven by a ramp input projects to one target
    ensemble per solver. Each target's low-pass-filtered output is compared
    with the filtered output of ``a`` via ``signals_allclose``. Every solver
    must individually meet the tolerances.
    """
    pytest.importorskip("sklearn")

    n_neurons = 70
    tfinal = 4

    solvers = [
        # decoder solvers
        Lstsq(),
        LstsqNoise(),
        LstsqL2(),
        LstsqL2nz(),
        LstsqL1(max_iter=5000),
        # weight solvers
        LstsqL1(weights=True, max_iter=5000),
        LstsqDrop(weights=True),
    ]

    def ramp(t):
        # -1 until t=1, linear rise to 1 at t=3, then held at 1
        return np.interp(t, [1, 3], [-1, 1], left=-1, right=1)

    model = nengo.Network(seed=seed)
    with model:
        stim = nengo.Node(output=ramp)
        a = nengo.Ensemble(n_neurons, dimensions=1)
        nengo.Connection(stim, a)
        ap = nengo.Probe(a)

        probes = []
        for solver in solvers:
            # fixed seed so all targets share neuron parameters; only the solver varies
            target = nengo.Ensemble(n_neurons, dimensions=1, seed=seed + 1)
            nengo.Connection(a, target, solver=solver)
            probes.append(nengo.Probe(target))

    labels = [
        "%s(%s)" % (type(solver).__name__, "w" if solver.weights else "d")
        for solver in solvers
    ]

    with Simulator(model) as sim:
        sim.run(tfinal)
    t = sim.trange()

    def smooth(x):
        return nengo.Lowpass(0.02).filtfilt(x, dt=sim.dt)

    ref = smooth(sim.data[ap])
    stacked = np.array([sim.data[probe][:, 0] for probe in probes])
    outputs_f = smooth(stacked.T)

    close = signals_allclose(
        t,
        ref,
        outputs_f,
        atol=0.07,
        rtol=0,
        buf=0.1,
        delay=0.007,
        plt=plt,
        labels=labels,
        individual_results=True,
        allclose=allclose,
    )

    for label, ok in zip(labels, close):
        assert ok, "Solver '%s' does not meet tolerances" % label
Esempio n. 3
0
def test_compare_solvers(Simulator, plt, seed):
    """Compare decoder and weight solvers on a ramp-tracking network.

    One target ensemble per solver is fed from a 1-D ensemble ``a`` that
    follows a ramp input. Filtered target outputs are compared with the
    filtered output of ``a`` using ``allclose``; each solver must pass on
    its own.
    """
    n_neurons = 70
    tfinal = 4

    solvers = [
        # decoder solvers
        Lstsq(),
        LstsqNoise(),
        LstsqL2(),
        LstsqL2nz(),
        LstsqL1(),
        # weight solvers
        LstsqL1(weights=True),
        LstsqDrop(weights=True),
    ]

    def ramp(t):
        # -1 until t=1, linear rise to 1 at t=3, then held at 1
        return np.interp(t, [1, 3], [-1, 1], left=-1, right=1)

    model = nengo.Network(seed=seed)
    with model:
        stim = nengo.Node(output=ramp)
        a = nengo.Ensemble(n_neurons, dimensions=1)
        nengo.Connection(stim, a)
        ap = nengo.Probe(a)

        probes = []
        for solver in solvers:
            # fixed seed so all targets share neuron parameters; only the solver varies
            target = nengo.Ensemble(n_neurons, dimensions=1, seed=seed + 1)
            nengo.Connection(a, target, solver=solver)
            probes.append(nengo.Probe(target))

    labels = [
        "%s(%s)" % (solver.__class__.__name__, 'w' if solver.weights else 'd')
        for solver in solvers
    ]

    sim = Simulator(model)
    sim.run(tfinal)
    t = sim.trange()

    def smooth(x):
        return nengo.synapses.filtfilt(x, 0.02, dt=sim.dt)

    ref = smooth(sim.data[ap])
    stacked = np.array([sim.data[probe][:, 0] for probe in probes])
    outputs_f = smooth(stacked.T)

    close = allclose(
        t,
        ref,
        outputs_f,
        atol=0.07,
        rtol=0,
        buf=0.1,
        delay=0.007,
        plt=plt,
        labels=labels,
        individual_results=True,
    )

    for label, ok in zip(labels, close):
        assert ok, "Solver '%s' does not meet tolerances" % label
Esempio n. 4
0
def test_solvers(Simulator, nl_nodirect):
    """Compare decoder and weight solvers on a ramp-tracking network.

    Each solver gets its own target ensemble fed from a 1-D ensemble ``a``
    that follows a ramp input. Filtered outputs are compared with the
    filtered output of ``a`` via ``allclose``, which also writes a plot to
    ``test_decoders.test_solvers.pdf`` through ``Plotter``.
    """
    n_neurons = 100
    dt = 1e-3
    tfinal = 4

    solvers = [
        # decoder solvers
        Lstsq(),
        LstsqNoise(),
        LstsqL2(),
        LstsqL2nz(),
        LstsqL1(),
        # weight solvers
        LstsqL1(weights=True),
        LstsqDrop(weights=True),
    ]

    def ramp(t):
        # -1 until t=1, linear rise to 1 at t=3, then held at 1
        return np.interp(t, [1, 3], [-1, 1], left=-1, right=1)

    model = nengo.Network('test_solvers', seed=290)
    with model:
        stim = nengo.Node(output=ramp)
        a = nengo.Ensemble(nl_nodirect(n_neurons), dimensions=1)
        nengo.Connection(stim, a)
        ap = nengo.Probe(a)

        probes = []
        for solver in solvers:
            # fixed seed so all targets share parameters; only the solver varies
            target = nengo.Ensemble(nl_nodirect(n_neurons), dimensions=1, seed=99)
            nengo.Connection(a, target, solver=solver)
            probes.append(nengo.Probe(target))

    labels = [
        "%s(%s)" % (solver.__class__.__name__, 'w' if solver.weights else 'd')
        for solver in solvers
    ]

    sim = Simulator(model, dt=dt)
    sim.run(tfinal)
    t = sim.trange()

    # NOTE(review): the filter constant 20 and buf/delay below look like
    # time-step counts (dt=1e-3) rather than seconds -- confirm with filtfilt.
    ref = filtfilt(sim.data[ap], 20)
    stacked = np.array([sim.data[probe] for probe in probes])
    outputs_f = filtfilt(stacked, 20, axis=1)

    plotter = Plotter(Simulator, nl_nodirect)
    close = allclose(
        t,
        ref,
        outputs_f,
        plotter=plotter,
        filename='test_decoders.test_solvers.pdf',
        labels=labels,
        atol=0.05,
        rtol=0,
        buf=100,
        delay=7,
    )
    for label, ok in zip(labels, close):
        assert ok, "Solver '%s' does not meet tolerances" % label