Пример #1
0
def test_compare_solvers(Simulator, plt, seed, allclose):
    """Check that every decoder and weight solver tracks a filtered ramp input.

    A ramp node (-1 before t=1, rising to 1 at t=3, held after) drives
    ensemble ``pre``. Each solver then builds its own connection from ``pre``
    into a fresh ensemble, and all of those ensembles share the seed
    ``seed + 1``. After both signals are zero-phase low-pass filtered
    (tau=0.02), each solver's output must match the filtered ``pre`` output
    to within ``atol=0.07``.
    """
    pytest.importorskip("sklearn")

    n_neurons = 70
    l1_iters = 5000
    # Decoder solvers first, then weight solvers, constructed in that order.
    solvers = [
        Lstsq(),
        LstsqNoise(),
        LstsqL2(),
        LstsqL2nz(),
        LstsqL1(max_iter=l1_iters),
        LstsqL1(weights=True, max_iter=l1_iters),
        LstsqDrop(weights=True),
    ]

    sim_time = 4

    def ramp(t):
        return np.interp(t, [1, 3], [-1, 1], left=-1, right=1)

    model = nengo.Network(seed=seed)
    with model:
        stim = nengo.Node(output=ramp)
        pre = nengo.Ensemble(n_neurons, dimensions=1)
        nengo.Connection(stim, pre)
        pre_probe = nengo.Probe(pre)

        probes, names = [], []
        for solver in solvers:
            post = nengo.Ensemble(n_neurons, dimensions=1, seed=seed + 1)
            nengo.Connection(pre, post, solver=solver)
            probes.append(nengo.Probe(post))
            # Label each result with the solver class and "w"/"d" for
            # weight vs. decoder solving.
            kind = "w" if solver.weights else "d"
            names.append("%s(%s)" % (type(solver).__name__, kind))

    with Simulator(model) as sim:
        sim.run(sim_time)
    t = sim.trange()

    smoother = nengo.Lowpass(0.02)
    ref = smoother.filtfilt(sim.data[pre_probe], dt=sim.dt)
    # One column per solver, built from the first dimension of each probe.
    outputs = np.column_stack([sim.data[p][:, 0] for p in probes])
    outputs_f = smoother.filtfilt(outputs, dt=sim.dt)

    close = signals_allclose(
        t, ref, outputs_f,
        atol=0.07, rtol=0, buf=0.1, delay=0.007,
        plt=plt, labels=names, individual_results=True, allclose=allclose,
    )

    for label, ok in zip(names, close):
        assert ok, "Solver '%s' does not meet tolerances" % label
Пример #2
0
def test_lstsql1_repr():
    """Check LstsqL1's declared init args, repr round-trip, and exact repr text."""
    pytest.importorskip("sklearn")

    check_init_args(LstsqL1, ["weights", "l1", "l2", "max_iter"])

    solver_kwargs = dict(weights=True, l1=0.2, l2=0.3, max_iter=4)
    check_repr(LstsqL1(**solver_kwargs))

    expected = "LstsqL1(weights=True, l1=0.2, l2=0.3, max_iter=4)"
    assert repr(LstsqL1(**solver_kwargs)) == expected
Пример #3
0
def test_compare_solvers(Simulator, plt, seed):
    """Check that every decoder and weight solver tracks a filtered ramp input.

    A ramp node (-1 before t=1, rising to 1 at t=3, held after) drives
    ensemble ``a``. Each solver connects ``a`` to its own ensemble ``b``, and
    all of those ensembles share the seed ``seed + 1``. After zero-phase
    filtering (tau=0.02), each ``b`` output must stay within ``atol=0.07`` of
    the filtered ``a`` output.
    """
    # LstsqL1 is only usable with scikit-learn; skip rather than error
    # without it, as the other LstsqL1 tests here do.
    pytest.importorskip("sklearn")

    N = 70
    decoder_solvers = [
        Lstsq(), LstsqNoise(),
        LstsqL2(), LstsqL2nz(),
        LstsqL1()
    ]
    weight_solvers = [LstsqL1(weights=True), LstsqDrop(weights=True)]

    tfinal = 4

    def input_function(t):
        return np.interp(t, [1, 3], [-1, 1], left=-1, right=1)

    model = nengo.Network(seed=seed)
    with model:
        u = nengo.Node(output=input_function)
        a = nengo.Ensemble(N, dimensions=1)
        nengo.Connection(u, a)
        ap = nengo.Probe(a)

        probes = []
        names = []
        for solver in decoder_solvers + weight_solvers:
            b = nengo.Ensemble(N, dimensions=1, seed=seed + 1)
            nengo.Connection(a, b, solver=solver)
            probes.append(nengo.Probe(b))
            # Label: solver class plus "w" (weights) or "d" (decoders).
            names.append(
                "%s(%s)" %
                (solver.__class__.__name__, 'w' if solver.weights else 'd'))

    # Use the simulator as a context manager so it is closed after running
    # (the original never closed it).
    with Simulator(model) as sim:
        sim.run(tfinal)
    t = sim.trange()

    ref = nengo.synapses.filtfilt(sim.data[ap], 0.02, dt=sim.dt)
    outputs = np.array([sim.data[probe][:, 0] for probe in probes]).T
    outputs_f = nengo.synapses.filtfilt(outputs, 0.02, dt=sim.dt)

    close = allclose(t,
                     ref,
                     outputs_f,
                     atol=0.07,
                     rtol=0,
                     buf=0.1,
                     delay=0.007,
                     plt=plt,
                     labels=names,
                     individual_results=True)

    for name, c in zip(names, close):
        assert c, "Solver '%s' does not meet tolerances" % name
Пример #4
0
def test_solvers(Simulator, nl_nodirect):
    """Check decoder and weight solvers against a filtered ramp reference.

    Ensemble sizes use ``nl_nodirect(N)`` for every ensemble. The reference is
    ensemble ``a``'s probed output, and each solver feeds its own ensemble
    (all with seed 99). Outputs are compared with ``filtfilt(..., 20)``
    smoothing, and a plot is written to 'test_decoders.test_solvers.pdf'.
    """
    n_neurons = 100
    # Decoder solvers followed by weight solvers, in the original order.
    solvers = [
        Lstsq(), LstsqNoise(),
        LstsqL2(), LstsqL2nz(),
        LstsqL1(),
        LstsqL1(weights=True), LstsqDrop(weights=True),
    ]

    dt, tfinal = 1e-3, 4

    def ramp(t):
        return np.interp(t, [1, 3], [-1, 1], left=-1, right=1)

    model = nengo.Network('test_solvers', seed=290)
    with model:
        stim = nengo.Node(output=ramp)
        pre = nengo.Ensemble(nl_nodirect(n_neurons), dimensions=1)
        nengo.Connection(stim, pre)
        pre_probe = nengo.Probe(pre)

        probes, names = [], []
        for solver in solvers:
            post = nengo.Ensemble(nl_nodirect(n_neurons), dimensions=1, seed=99)
            nengo.Connection(pre, post, solver=solver)
            probes.append(nengo.Probe(post))
            kind = 'w' if solver.weights else 'd'
            names.append("%s(%s)" % (solver.__class__.__name__, kind))

    sim = Simulator(model, dt=dt)
    sim.run(tfinal)
    t = sim.trange()

    ref = filtfilt(sim.data[pre_probe], 20)
    # Stack one array per probe along a new leading axis; filter along time.
    stacked = np.array([sim.data[p] for p in probes])
    stacked_f = filtfilt(stacked, 20, axis=1)

    close = allclose(
        t, ref, stacked_f,
        plotter=Plotter(Simulator, nl_nodirect),
        filename='test_decoders.test_solvers.pdf',
        labels=names,
        atol=0.05, rtol=0, buf=100, delay=7,
    )
    for label, ok in zip(names, close):
        assert ok, "Solver '%s' does not meet tolerances" % label
Пример #5
0
def test_subsolvers_L1(rng):
    """Time an LstsqL1 solve (l1=1e-4, l2=0) on a generated system and print it.

    The system comes from ``get_system(m=2000, n=1000, d=10, rng=rng)``.
    """
    # LstsqL1 depends on scikit-learn; skip (not error) when it's missing,
    # matching the other LstsqL1 tests in this file.
    pytest.importorskip("sklearn")

    A, B = get_system(m=2000, n=1000, d=10, rng=rng)

    l1 = 1e-4
    with Timer() as t:
        LstsqL1(l1=l1, l2=0)(A, B, rng=rng)
    print(t.duration)
Пример #6
0
def test_subsolvers_L1():
    """Time an LstsqL1 solve (l1=1e-4, l2=0) with a fixed seed and print it.

    The system comes from ``get_system(m=2000, n=1000, d=10, rng=rng)``, using
    ``np.random.RandomState(39408)`` for reproducibility.
    """
    # LstsqL1 depends on scikit-learn; skip (not error) when it's missing,
    # matching the other LstsqL1 tests in this file.
    pytest.importorskip("sklearn")

    rng = np.random.RandomState(39408)
    A, B = get_system(m=2000, n=1000, d=10, rng=rng)

    l1 = 1e-4
    with Timer() as t:
        LstsqL1(l1=l1, l2=0)(A, B, rng=rng)
    print(t.duration)
Пример #7
0
def test_subsolvers_L1(rng, logger):
    """Time an LstsqL1 solve (l1=1e-4, l2=0) and log the duration via ``logger``."""
    pytest.importorskip('sklearn')

    A, B = get_system(m=2000, n=1000, d=10, rng=rng)

    with Timer() as timer:
        LstsqL1(l1=1e-4, l2=0)(A, B, rng=rng)

    logger.info('duration: %0.3f', timer.duration)
Пример #8
0
def test_subsolvers_L1(rng):
    """Time an LstsqL1 solve (l1=1e-4, l2=0), logging the result at INFO level."""
    pytest.importorskip("sklearn")

    A, B = get_system(m=2000, n=1000, d=10, rng=rng)

    penalty = 1e-4
    with Timer() as timer:
        LstsqL1(l1=penalty, l2=0)(A, B, rng=rng)

    logging.info("duration: %0.3f", timer.duration)
Пример #9
0
def test_subsolvers_L1(rng):
    """Time an LstsqL1 solve (l1=1e-4, l2=0) and print the elapsed duration."""
    pytest.importorskip('sklearn')

    A, B = get_system(m=2000, n=1000, d=10, rng=rng)

    with Timer() as timer:
        LstsqL1(l1=1e-4, l2=0)(A, B, rng=rng)

    print(timer.duration)
Пример #10
0
def test_subsolvers_L1(rng, allclose):
    """Check accuracy of an LstsqL1 solve (l1=1e-4, l2=0) and log its duration.

    Asserts that the reconstruction ``A @ x`` is close to ``B`` by RMS
    (< 2e-2) and elementwise (atol=0.2). It also asserts that the largest
    value in the solver's ``info["rmses"]`` is below 3e-2.
    """
    pytest.importorskip("sklearn")

    A, B = get_system(m=2000, n=1000, d=10, rng=rng)

    with Timer() as timer:
        x, info = LstsqL1(l1=1e-4, l2=0)(A, B, rng=rng)
    logging.info("duration: %0.3f", timer.duration)

    reconstruction = np.dot(A, x)
    residual = reconstruction - B
    assert rms(residual) < 2e-2
    assert allclose(reconstruction, B, atol=0.2, record_rmse=False)

    worst_rmse = np.max(info["rmses"])
    assert worst_rmse < 3e-2